Пример #1
0
    def __init__(self,
                 classes,
                 num_rels,
                 mode='sgdet',
                 embed_dim=200,
                 pooling_dim=4096,
                 use_bias=True):
        """
        Build the relationship head's sub-modules.

        :param classes: object classes (also handed to the LC context module)
        :param num_rels: number of relationship classes; sizes rel_compress
        :param mode: must be one of MODES
        :param embed_dim: width of the class embedding (ort_embedding, and the
            embed_dim slot of post_lstm / post_emb)
        :param pooling_dim: width of the union-box feature coming out of
            roi_fmap; a Linear(4096, pooling_dim) is appended when != 4096
        :param use_bias: when True a FrequencyBias module is created
        """
        super(EndCell, self).__init__()
        self.classes = classes
        self.num_rels = num_rels
        assert mode in MODES
        self.embed_dim = embed_dim
        self.pooling_dim = pooling_dim
        self.use_bias = use_bias
        self.mode = mode

        # Class-embedding table, built directly on the GPU (.cuda()).
        self.ort_embedding = torch.autograd.Variable(
            get_ort_embeds(self.num_classes, self.embed_dim).cuda())
        self.context = LC(classes=self.classes, mode=self.mode,
                          embed_dim=self.embed_dim, obj_dim=self.pooling_dim)

        self.pooling_size = 7
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16, dim=512)

        # Union-box head: flatten -> VGG classifier [-> project to pooling_dim].
        # Sub-modules are created in the same order as before so random
        # initialisation draws happen in the same sequence.
        head_layers = [Flattener()]
        head_layers.append(load_vgg(use_dropout=False,
                                    use_relu=False,
                                    use_linear=pooling_dim == 4096,
                                    pretrained=False).classifier)
        if pooling_dim != 4096:
            head_layers.append(nn.Linear(4096, pooling_dim))
        self.roi_fmap = nn.Sequential(*head_layers)
        self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        # Both projections take [pooling_dim | embed_dim | 5] features.
        fused_width = self.pooling_dim + self.embed_dim + 5

        self.post_lstm = nn.Linear(fused_width, self.pooling_dim * 2)
        # Weights ~ N(0, std^2) with std = 10 * sqrt(1 / pooling_dim), bias 0.
        # Author's rationale: aim for roughly unit-variance outputs, and the
        # pre-LSTM features tend to have stdev ~0.1, hence the factor of 10.
        post_lstm_std = 10.0 * math.sqrt(1.0 / self.pooling_dim)
        self.post_lstm.weight.data.normal_(0, post_lstm_std)
        self.post_lstm.bias.data.zero_()

        self.post_emb = nn.Linear(fused_width, self.pooling_dim * 2)

        self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)

        if self.use_bias:
            self.freq_bias = FrequencyBias()
Пример #2
0
    def __init__(self,
                 train_data,
                 mode='sgdet',
                 num_gpus=1,
                 require_overlap_det=True,
                 use_bias=False,
                 test_bias=False,
                 detector_model='baseline',
                 RELS_PER_IMG=1024):
        """
        Shared setup for relationship models built on a torchvision detector.

        :param train_data: dataset object; its ``ind_to_classes`` and
            ``ind_to_predicates`` become ``classes`` / ``rel_classes``, and it
            is passed to FrequencyBias when ``use_bias`` is set
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param require_overlap_det: Whether two objects must intersect
            (only applied in sgdet mode)
        :param use_bias: create ``self.freq_bias = FrequencyBias(train_data)``
        :param test_bias: stored as ``self.test_bias``
        :param detector_model: only 'mrcnn' is supported; any other value
            raises NotImplementedError once ``union_boxes`` has been built
        :param RELS_PER_IMG: stored as ``self.RELS_PER_IMG``
        """
        super(RelModelBase, self).__init__()
        self.classes = train_data.ind_to_classes
        self.rel_classes = train_data.ind_to_predicates
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        self.detector_model = detector_model
        self.RELS_PER_IMG = RELS_PER_IMG
        self.pooling_size = 7
        self.stride = 16
        self.obj_dim = 4096

        self.use_bias = use_bias
        self.test_bias = test_bias

        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        is_mrcnn = self.detector_model == 'mrcnn'

        if is_mrcnn:
            print('\nLoading COCO pretrained model maskrcnn_resnet50_fpn...\n')
            # See https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
            builders = torchvision.models.detection
            self.detector = builders.maskrcnn_resnet50_fpn(
                pretrained=True,
                box_detections_per_img=50,
                box_score_thresh=0.2)
            heads = self.detector.roi_heads
            box_feat_width = heads.box_predictor.cls_score.in_features
            # Swap the COCO box classifier for one sized to our classes and
            # drop the mask branch.
            heads.box_predictor = FastRCNNPredictor(box_feat_width,
                                                    len(self.classes))
            heads.mask_predictor = None

        union_width = 256 if is_mrcnn else 512
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=self.stride,
                                              dim=union_width)

        if not is_mrcnn:
            raise NotImplementedError(self.detector_model)

        # First two roi_heads children: [0] -> multiscale_roi_pool,
        # [1] -> copied into both roi_fmap_obj and roi_fmap.
        first_two = list(self.detector.roi_heads.children())[:2]
        self.roi_fmap_obj = copy.deepcopy(first_two[1])
        self.roi_fmap = copy.deepcopy(first_two[1])
        self.multiscale_roi_pool = copy.deepcopy(first_two[0])

        if self.use_bias:
            self.freq_bias = FrequencyBias(train_data)
Пример #3
0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        Build the detector, object-context and relationship-scoring modules.

        :param classes: Object classes
        :param rel_classes: Relationship classes
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision in the final product
            (stored as ``self.use_vision``)
        :param require_overlap_det: Whether two objects must intersect
            (only applied in sgdet mode)
        :param embed_dim: Dimension for all embeddings (including the
            ``ort_embedding`` table consumed next to obj features)
        :param hidden_dim: LSTM hidden size
        :param pooling_dim: width of the union-box feature from ``roi_fmap``
        :param nl_obj, nl_edge, rec_dropout, order,
            pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge:
            forwarded to LinearizedContext
        :param use_resnet: selects obj_dim 2048 (ResNet) instead of 4096 (VGG)
        :param thresh: forwarded to ObjectDetector as ``thresh``
        :param use_proposals: in sgdet mode, run the detector in 'proposals'
            mode instead of 'refinerels'
        :param use_bias: create a FrequencyBias module
        :param use_tanh: stored as ``self.use_tanh``
        :param limit_vision: stored as ``self.limit_vision``
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        # Width of the per-object ROI feature (obj_fmap); merge_obj_feats and
        # merge_obj_low are both sized with it.
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        self.hook_for_grad = False
        self.gradients = []

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )
        # Sized with embed_dim (was a hard-coded 200) so its width matches the
        # embed_dim slot reserved for it in merge_obj_low.
        self.ort_embedding = torch.autograd.Variable(
            get_ort_embeds(self.num_classes, self.embed_dim).cuda())
        embed_vecs = obj_edge_vectors(self.classes, wv_dim=self.embed_dim)
        self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim)
        self.obj_embed.weight.data = embed_vecs.clone()

        # This probably doesn't help it much
        self.pos_embed = nn.Sequential(*[
            nn.BatchNorm1d(4, momentum=BATCHNORM_MOMENTUM / 10.0),
            nn.Linear(4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ])

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        # [obj_fmap (obj_dim) | class embedding (embed_dim) | pos_embed (128)]
        self.merge_obj_feats = nn.Sequential(
            nn.Linear(self.obj_dim + self.embed_dim + 128, self.hidden_dim),
            nn.ReLU())

        self.get_phr_feats = nn.Linear(self.pooling_dim, self.hidden_dim)

        self.embeddings4lstm = nn.Embedding(self.num_classes, self.embed_dim)

        self.lstm = nn.LSTM(input_size=self.hidden_dim + self.embed_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=1)

        self.obj_mps1 = Message_Passing4OBJ(self.hidden_dim)
        self.get_boxes_encode = Boxes_Encode(64)

        if use_resnet:
            # NOTE(review): roi_fmap_obj is only created in the VGG branch.
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        self.post_emb_s = nn.Linear(self.pooling_dim, self.pooling_dim)
        self.post_emb_s.weight = torch.nn.init.xavier_normal(
            self.post_emb_s.weight, gain=1.0)
        self.post_emb_o = nn.Linear(self.pooling_dim, self.pooling_dim)
        self.post_emb_o.weight = torch.nn.init.xavier_normal(
            self.post_emb_o.weight, gain=1.0)
        self.merge_obj_high = nn.Linear(self.hidden_dim, self.pooling_dim)
        self.merge_obj_high.weight = torch.nn.init.xavier_normal(
            self.merge_obj_high.weight, gain=1.0)
        # Input: [obj_fmap (obj_dim) | ort class embedding (embed_dim) | 5 box
        # values]. Sizing the first slot with pooling_dim only worked when
        # pooling_dim happened to equal obj_dim.
        self.merge_obj_low = nn.Linear(self.obj_dim + 5 + self.embed_dim,
                                       self.pooling_dim)
        self.merge_obj_low.weight = torch.nn.init.xavier_normal(
            self.merge_obj_low.weight, gain=1.0)
        # +64: Boxes_Encode(64) pair features are concatenated before these.
        self.rel_compress = nn.Linear(self.pooling_dim + 64,
                                      self.num_rels,
                                      bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)
        self.freq_gate = nn.Linear(self.pooling_dim + 64,
                                   self.num_rels,
                                   bias=True)
        self.freq_gate.weight = torch.nn.init.xavier_normal(
            self.freq_gate.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()
Пример #4
0
class RelModel(nn.Module):
    """
    RELATIONSHIPS

    Relationship model: an ObjectDetector produces boxes, class distributions
    and a feature map; per-object features are refined with an LSTM and
    Message_Passing4OBJ, and each candidate (subject, object) pair is scored
    from the fused subject/object representations, the union-box visual
    feature, a Boxes_Encode pair encoding and, optionally, a gated
    frequency bias.
    """
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        Build the detector, object-context and relationship-scoring modules.

        :param classes: Object classes
        :param rel_classes: Relationship classes; ``len(rel_classes)`` sizes
            the relation classifiers, so it must not be None
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use (see __getitem__)
        :param use_vision: stored as ``self.use_vision``
        :param require_overlap_det: Whether two objects must intersect
            (only applied in sgdet mode)
        :param embed_dim: Dimension for all embeddings (including the
            ``ort_embedding`` table concatenated into merge_obj_low's input)
        :param hidden_dim: LSTM hidden size
        :param pooling_dim: width of the union-box feature from ``roi_fmap``
        :param nl_obj, nl_edge, rec_dropout, order,
            pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge:
            forwarded to LinearizedContext
        :param use_resnet: selects obj_dim 2048 (ResNet) instead of 4096 (VGG)
        :param thresh: forwarded to ObjectDetector as ``thresh``
        :param use_proposals: in sgdet mode, run the detector in 'proposals'
            mode instead of 'refinerels'
        :param use_bias: add ``freq_gate * freq_bias`` to the relation logits
        :param use_tanh: apply tanh to the fused pair representation
        :param limit_vision: stored as ``self.limit_vision``
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        # Width of the per-object ROI feature (obj_fmap); merge_obj_feats and
        # merge_obj_low are both sized with it.
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        # When True, forward() keeps the feature map in the graph, records its
        # gradient via save_grad and scores the ground-truth relations.
        self.hook_for_grad = False
        self.gradients = []

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )
        # Sized with embed_dim (was a hard-coded 200) so its width matches the
        # embed_dim slot reserved for it in merge_obj_low.
        self.ort_embedding = torch.autograd.Variable(
            get_ort_embeds(self.num_classes, self.embed_dim).cuda())
        embed_vecs = obj_edge_vectors(self.classes, wv_dim=self.embed_dim)
        self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim)
        self.obj_embed.weight.data = embed_vecs.clone()

        # This probably doesn't help it much
        self.pos_embed = nn.Sequential(*[
            nn.BatchNorm1d(4, momentum=BATCHNORM_MOMENTUM / 10.0),
            nn.Linear(4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ])

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        # [obj_fmap (obj_dim) | soft class embedding (embed_dim) | pos (128)]
        self.merge_obj_feats = nn.Sequential(
            nn.Linear(self.obj_dim + self.embed_dim + 128, self.hidden_dim),
            nn.ReLU())

        self.get_phr_feats = nn.Linear(self.pooling_dim, self.hidden_dim)

        self.embeddings4lstm = nn.Embedding(self.num_classes, self.embed_dim)

        self.lstm = nn.LSTM(input_size=self.hidden_dim + self.embed_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=1)

        self.obj_mps1 = Message_Passing4OBJ(self.hidden_dim)
        self.get_boxes_encode = Boxes_Encode(64)

        if use_resnet:
            # NOTE(review): roi_fmap_obj is only created in the VGG branch, but
            # obj_feature_map() requires it.
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        self.post_emb_s = nn.Linear(self.pooling_dim, self.pooling_dim)
        self.post_emb_s.weight = torch.nn.init.xavier_normal(
            self.post_emb_s.weight, gain=1.0)
        self.post_emb_o = nn.Linear(self.pooling_dim, self.pooling_dim)
        self.post_emb_o.weight = torch.nn.init.xavier_normal(
            self.post_emb_o.weight, gain=1.0)
        self.merge_obj_high = nn.Linear(self.hidden_dim, self.pooling_dim)
        self.merge_obj_high.weight = torch.nn.init.xavier_normal(
            self.merge_obj_high.weight, gain=1.0)
        # forward() feeds [obj_fmap (obj_dim) | ort embedding (embed_dim) |
        # 5 box values]. Sizing the first slot with pooling_dim only worked
        # when pooling_dim happened to equal obj_dim.
        self.merge_obj_low = nn.Linear(self.obj_dim + 5 + self.embed_dim,
                                       self.pooling_dim)
        self.merge_obj_low.weight = torch.nn.init.xavier_normal(
            self.merge_obj_low.weight, gain=1.0)
        # +64: the Boxes_Encode(64) pair features concatenated in forward().
        self.rel_compress = nn.Linear(self.pooling_dim + 64,
                                      self.num_rels,
                                      bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)
        self.freq_gate = nn.Linear(self.pooling_dim + 64,
                                   self.num_rels,
                                   bias=True)
        self.freq_gate.weight = torch.nn.init.xavier_normal(
            self.freq_gate.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()

    @property
    def num_classes(self):
        """Number of object classes."""
        return len(self.classes)

    @property
    def num_rels(self):
        """Number of relationship classes."""
        return len(self.rel_classes)

    def save_grad(self, grad):
        """Backward hook registered on the feature map: record its gradient."""
        self.gradients.append(grad)

    def visual_rep(self, features, rois, pair_inds):
        """
        Union-box visual features for pairs of boxes.

        :param features: detector feature map (``result.fmap`` in forward)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :param pair_inds: [num_pairs, 2] box indices to pair up
        :return: ``self.roi_fmap`` applied to the pooled union-box features,
                 one row per pair
        """
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return self.roi_fmap(uboxes)

    def visual_obj(self, features, rois, pair_inds):
        """Like visual_rep, but return the pooled union-box features as-is."""
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return uboxes

    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        """
        Get the candidate relationship rows ``[img_ind, subj_ind, obj_ind]``.

        In training mode the first three columns of ``rel_labels`` are used.
        Otherwise every ordered pair of distinct boxes from the same image is
        a candidate (restricted to overlapping boxes when require_overlap);
        if none survive, the single pair (0, 0) is returned, prefixed with
        box 0's image index.
        """
        if self.training:
            return rel_labels[:, :3].data.clone()

        rel_cands = im_inds.data[:, None] == im_inds.data[None]
        rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

        # Require overlap for detection
        if self.require_overlap:
            rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                   box_priors.data) > 0)

        rel_cands = rel_cands.nonzero()
        # Old PyTorch returns a 0-dim tensor from nonzero() when nothing is
        # selected; newer versions return an empty (0, 2) tensor. Treat both
        # as "no candidates" so the fallback actually fires.
        if rel_cands.dim() == 0 or rel_cands.numel() == 0:
            rel_cands = im_inds.data.new(1, 2).fill_(0)

        return torch.cat(
            (im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)

    def union_pairs(self, im_inds):
        """
        All ordered pairs (i, j), i != j, of boxes from the same image, as
        rows ``[img_ind, i, j]``.
        """
        rel_cands = im_inds.data[:, None] == im_inds.data[None]
        rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0
        rel_inds = rel_cands.nonzero()
        rel_inds = torch.cat((im_inds[rel_inds[:, 0]][:, None].data, rel_inds),
                             -1)
        return rel_inds

    def obj_feature_map(self, features, rois):
        """
        Gets the ROI features
        :param features: detector feature map; pooled with spatial_scale=1/16
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: [num_rois, #dim] array from ``self.roi_fmap_obj``
        """
        feature_pool = RoIAlignFunction(self.pooling_size,
                                        self.pooling_size,
                                        spatial_scale=1 / 16)(features, rois)
        return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1))

    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param gt_rels: ground-truth relations; with hook_for_grad set, all
                        but the last column are used as rel_inds
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :param return_fmap: unused; the detector is always asked for its fmap
        :return: If train: the detector result, extended with obj_fmap,
            subj_rep, obj_rep, rel_dists, refined rm_obj_dists and obj_preds.
            If test: the output of filter_dets(boxes, obj_scores, obj_preds,
            rel_inds, rel probabilities).
        :raises ValueError: if the detector returns an empty result
        """
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)

        if result.is_none():
            # Raise (the original *returned* the exception object, which
            # callers would then try to use as a result).
            raise ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)
        if self.hook_for_grad:
            # Override before anything is derived from rel_inds so spt_feats
            # stays row-aligned with the pair representations below.
            rel_inds = gt_rels[:, :-1].data
        spt_feats = self.get_boxes_encode(boxes, rel_inds)
        pair_inds = self.union_pairs(im_inds)

        if self.hook_for_grad:
            # Keep fmap in the graph so save_grad receives its gradient.
            fmap = result.fmap
            fmap.register_hook(self.save_grad)
        else:
            fmap = result.fmap.detach()

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(fmap, rois)

        # Soft class embeddings: class probabilities times embedding tables.
        obj_embed = F.softmax(result.rm_obj_dists,
                              dim=1) @ self.obj_embed.weight
        obj_embed_lstm = F.softmax(result.rm_obj_dists,
                                   dim=1) @ self.embeddings4lstm.weight
        pos_embed = self.pos_embed(Variable(center_size(boxes.data)))
        obj_pre_rep = torch.cat((result.obj_fmap, obj_embed, pos_embed), 1)
        obj_feats = self.merge_obj_feats(obj_pre_rep)
        # LSTM input of shape (1, num_objects, features).
        obj_feats_lstm = torch.cat(
            (obj_feats, obj_embed_lstm),
            -1).contiguous().view(1, obj_feats.size(0), -1)

        phr_ori = self.visual_rep(fmap, rois, pair_inds[:, 1:])
        # For each row of rel_inds, pick the row of pair_inds (and phr_ori)
        # with the same (subject, object) pair; presumably intersect_2d
        # yields a [num_rels, num_pairs] match matrix.
        vr_indices = torch.from_numpy(
            intersect_2d(rel_inds[:, 1:].cpu().numpy(),
                         pair_inds[:, 1:].cpu().numpy()).astype(
                             np.uint8)).cuda().max(-1)[1]
        vr = phr_ori[vr_indices]

        phr_feats_high = self.get_phr_feats(phr_ori)

        # Refinement round 1: LSTM step, class-logit residual, message passing.
        obj_feats_lstm_output, (obj_hidden_states,
                                obj_cell_states) = self.lstm(obj_feats_lstm)

        # squeeze(0) drops only the length-1 leading axis, so a single object
        # still gives a [1, hidden_dim] matrix.
        rm_obj_dists1 = result.rm_obj_dists + self.context.decoder_lin(
            obj_feats_lstm_output.squeeze(0))
        obj_feats_output = self.obj_mps1(
            obj_feats_lstm_output.view(-1, obj_feats_lstm_output.size(-1)),
            phr_feats_high, im_inds, pair_inds)

        obj_embed_lstm1 = F.softmax(rm_obj_dists1,
                                    dim=1) @ self.embeddings4lstm.weight

        # Refinement round 2, continuing from the round-1 LSTM state.
        obj_feats_lstm1 = torch.cat(
            (obj_feats_output, obj_embed_lstm1),
            -1).contiguous().view(1, obj_feats_output.size(0), -1)
        obj_feats_lstm_output, _ = self.lstm(
            obj_feats_lstm1, (obj_hidden_states, obj_cell_states))

        rm_obj_dists2 = rm_obj_dists1 + self.context.decoder_lin(
            obj_feats_lstm_output.squeeze(0))
        obj_feats_output = self.obj_mps1(
            obj_feats_lstm_output.view(-1, obj_feats_lstm_output.size(-1)),
            phr_feats_high, im_inds, pair_inds)

        # Prevent gradients from flowing back into score_fc from elsewhere
        obj_labels = (result.rm_obj_labels
                      if self.training or self.mode == 'predcls' else None)
        result.rm_obj_dists, result.obj_preds = self.context(
            rm_obj_dists2, obj_feats_output, obj_labels, boxes.data,
            result.boxes_all)

        obj_dtype = result.obj_fmap.data.type()
        obj_preds_embeds = torch.index_select(self.ort_embedding, 0,
                                              result.obj_preds).type(obj_dtype)
        # Five box values scaled by IM_SCALE: four coordinates (in the column
        # order 0, 3, 2, 1) and the box area over IM_SCALE**2.
        tranfered_boxes = torch.stack(
            (boxes[:, 0] / IM_SCALE, boxes[:, 3] / IM_SCALE,
             boxes[:, 2] / IM_SCALE, boxes[:, 1] / IM_SCALE,
             ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])) /
             (IM_SCALE**2)), -1).type(obj_dtype)
        obj_features = torch.cat(
            (result.obj_fmap, obj_preds_embeds, tranfered_boxes), -1)
        obj_features_merge = self.merge_obj_low(
            obj_features) + self.merge_obj_high(obj_feats_output)

        # Split into subject and object representations
        result.subj_rep = self.post_emb_s(obj_features_merge)[rel_inds[:, 1]]
        result.obj_rep = self.post_emb_o(obj_features_merge)[rel_inds[:, 2]]
        prod_rep = result.subj_rep * result.obj_rep
        prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        prod_rep = torch.cat((prod_rep, spt_feats), -1)
        freq_gate = self.freq_gate(prod_rep)
        freq_gate = F.sigmoid(freq_gate)
        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            result.rel_dists = result.rel_dists + freq_gate * self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        # Flat index of each object's predicted class in the [N, C] layout.
        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)

        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)

    def __getitem__(self, batch):
        """
        Hack to do multi-GPU training.

        Scatters ``batch``; with one GPU runs ``self(*batch[0])``, otherwise
        replicates the model and applies each replica to its shard. Training
        outputs are gathered with gather_res; eval outputs are returned as
        the per-replica list.
        """
        batch.scatter()
        if self.num_gpus == 1:
            return self(*batch[0])

        replicas = nn.parallel.replicate(self,
                                         devices=list(range(self.num_gpus)))
        outputs = nn.parallel.parallel_apply(
            replicas, [batch[i] for i in range(self.num_gpus)])

        if self.training:
            return gather_res(outputs, 0, dim=0)
        return outputs
Пример #5
0
class RelModel(nn.Module):
    """
    RELATIONSHIPS

    Relationship model: an ObjectDetector, a LinearizedContext context module,
    union-box visual features, an optional FrequencyBias term, and (in training)
    positive/negative score pairs stored on the result as ``pos`` / ``neg`` /
    ``anchor`` for a ranking-style loss computed by the caller.
    """
    def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True,
                 embed_dim=200, hidden_dim=256, pooling_dim=2048,
                 nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01,
                 use_proposals=False, pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True,
                 limit_vision=True):

        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to fuse union-box visual features into the relation representation
        :param require_overlap_det: Whether two objects must intersect (only applied in sgdet mode)
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size (input size of post_lstm)
        :param pooling_dim: Dimension of the union-box features; post_lstm outputs 2 * pooling_dim
        :param nl_obj: forwarded to LinearizedContext
        :param nl_edge: forwarded to LinearizedContext; when 0, a class-embedding ``post_emb``
                        is created (used in forward when the context returns no edge_ctx)
        :param use_resnet: ResNet backbone instead of VGG
        :param order: object ordering, forwarded to LinearizedContext
        :param thresh: detector score threshold
        :param use_proposals: use precomputed proposals in sgdet mode
        :param pass_in_obj_feats_to_decoder: forwarded to LinearizedContext
        :param pass_in_obj_feats_to_edge: forwarded to LinearizedContext
        :param rec_dropout: recurrent dropout rate, forwarded to LinearizedContext
        :param use_bias: add a FrequencyBias term to the relation logits
        :param use_tanh: apply tanh to the fused relation representation
        :param limit_vision: only multiply vision into the first 2048 dims of the representation
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(self.classes, self.rel_classes, mode=self.mode,
                                         embed_dim=self.embed_dim, hidden_dim=self.hidden_dim,
                                         obj_dim=self.obj_dim,
                                         nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout,
                                         order=order,
                                         pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
                                         pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            # NOTE(review): only the VGG branch creates roi_fmap_obj, which
            # obj_feature_map (and therefore forward) relies on.
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)

        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half contribution comes from LSTM, half from embedding.)

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        if nl_edge == 0:
            self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim*2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()

        # Number of negative relation labels sampled per positive pair. Keep it small:
        # within one image most relations share class 0, so a large value mostly
        # repeats the same negative relation.
        self.neg_num = 1

    @property
    def num_classes(self):
        return len(self.classes)

    @property
    def num_rels(self):
        return len(self.rel_classes)

    def visual_rep(self, features, rois, pair_inds):
        """
        Compute union-box visual features for pairs of rois.

        :param features: backbone feature map, passed to UnionBoxesAndFeats (stride 16)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :param pair_inds: [num_pairs, 2] indices into rois (subject, object)
        :return: ``self.roi_fmap`` applied to the pooled union-box features, one row per pair
        """
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return self.roi_fmap(uboxes)

    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        """
        Get the candidate relationship index triples.

        :param rel_labels: training only; [num_rels, >=3] rows of (img ind, box0 ind, box1 ind, ...)
        :param im_inds: [num_boxes] image index of every box
        :param box_priors: [num_boxes, 4] boxes, used by the overlap filter in sgdet mode
        :return: [num_cands, 3] tensor of (img ind, box0 ind, box1 ind)
        """
        if self.training:
            # Training: candidates are exactly the labelled pairs.
            rel_inds = rel_labels[:, :3].data.clone()
        else:
            # Every ordered pair of distinct boxes that belong to the same image.
            rel_cands = im_inds.data[:, None] == im_inds.data[None]
            rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

            # Require overlap for detection
            if self.require_overlap:
                rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                       box_priors.data) > 0)

            rel_cands = rel_cands.nonzero()
            # An empty nonzero() result is 0-dim on old PyTorch (what the original
            # check targeted) but has shape [0, 2] on newer versions; in both cases
            # fall back to a single (0, 0) pair so downstream code gets one candidate.
            if rel_cands.dim() == 0 or rel_cands.size(0) == 0:
                rel_cands = im_inds.data.new(1, 2).fill_(0)

            rel_inds = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
        return rel_inds

    def obj_feature_map(self, features, rois):
        """
        Gets the ROI features for every box.

        :param features: backbone feature map, sampled with RoIAlign at spatial_scale 1/16
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: [num_rois, #dim] output of the VGG classifier head ``roi_fmap_obj``
                 (only created for the VGG backbone, i.e. use_resnet=False)
        """
        # 1.0 / 16 rather than 1 / 16 so the scale is not integer-divided to 0 on Python 2.
        feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size, spatial_scale=1.0 / 16)(
            features, rois)
        return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1))  # vgg.classifier

    def new_obj_feature_map(self, features, rois):
        """
        Same as obj_feature_map, but through ``self.new_roi_fmap_obj``.

        NOTE(review): __init__ never creates ``new_roi_fmap_obj``, so calling this
        raises AttributeError unless it is attached externally.

        :param features: backbone feature map, sampled with RoIAlign at spatial_scale 1/16
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: [num_rois, #dim] array
        """
        feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size, spatial_scale=1.0 / 16)(
            features, rois)
        return self.new_roi_fmap_obj(feature_pool.view(rois.size(0), -1))  # vgg.classifier

    def get_neg_examples(self, rel_labels):
        """
        Given relationship combinations (positive examples), return negative examples.

        Every positive (img ind, box0 ind, box1 ind) triple is repeated ``self.neg_num``
        times and paired with a relation type different from its own. That type is
        sampled (with replacement) from the other relation types present in the same
        image. If the image has no other types, it is sampled (without replacement)
        from all relation classes except its own.

        :param rel_labels: [num_rels, 4] (img ind, box0 ind, box1 ind, rel type)
        :return: neg_rel_labels: [num_rels * neg_num, 4] LongTensor
                 (img ind, box0 ind, box1 ind, rel type) on rel_labels' CUDA device
        """
        rel_np = rel_labels.data.cpu().numpy()
        im_inds = rel_np[:, 0]
        rel_type = rel_np[:, 3]
        box_pairs = rel_np[:, :3]
        num_im = int(im_inds.max()) + 1
        all_rel_cls = np.arange(self.num_rels)

        neg_rel_labels = []
        for im_ind in range(num_im):
            pred_ind = np.where(im_inds == im_ind)[0]
            rel_type_i = rel_type[pred_ind]

            # Repeat every positive triple neg_num times, row by row.
            neg_pairs_i = np.repeat(box_pairs[pred_ind], self.neg_num, axis=0)

            neg_types_i = []
            for own_type in rel_type_i:
                # Other relation types in this image (every occurrence of own_type removed).
                other_types = rel_type_i[rel_type_i != own_type]
                if other_types.shape[0] != 0:
                    neg_types_i.append(np.random.choice(other_types, size=self.neg_num, replace=True))
                else:
                    cls_pool = all_rel_cls[all_rel_cls != own_type]
                    neg_types_i.append(np.random.choice(cls_pool, size=self.neg_num, replace=False))
            neg_rel_type_i = (np.concatenate(neg_types_i) if neg_types_i
                              else np.zeros(0, dtype=np.int64))

            assert neg_pairs_i.shape[0] == neg_rel_type_i.shape[0]
            neg_rel_labels.append(np.column_stack((neg_pairs_i, neg_rel_type_i)))

        # Per-image arrays have different row counts; concatenate the list directly
        # (wrapping it in np.array() creates a ragged array, which newer NumPy rejects).
        neg_rel_labels = torch.LongTensor(np.concatenate(neg_rel_labels, 0).astype(np.int64))
        # The source tensor is in pageable (non-pinned) memory, so an asynchronous copy
        # would not actually be asynchronous; a plain .cuda() also avoids `async=`,
        # which is a reserved word (SyntaxError) since Python 3.7.
        return neg_rel_labels.cuda(rel_labels.get_device())

    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for relationship detection.

        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param gt_rels: GT relations; passed to the detector and to rel_assignments
        :param proposals: precomputed proposals; passed to the detector
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :param return_fmap: unused; the detector is always called with return_fmap=True
        :return: If train: the detector result, extended with rel_dists, obj_scores and the
                 ranking tensors ``pos``, ``neg`` and ``anchor``.
                 If test: the output of filter_dets(bboxes, obj_scores, obj_preds, pair inds, rel probs).
        :raises ValueError: if the detector returns an empty result.
        """

        # Detector
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)
        if result.is_none():
            # Previously the exception object was *returned*, handing callers a
            # ValueError instance in place of a result; raise it instead.
            raise ValueError("Detector returned an empty result")
        im_inds = result.im_inds - image_offset
        # boxes: [#boxes, 4], without box deltas; detached (a narrow() error otherwise comes from here)
        boxes = result.rm_box_priors.detach()

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet' # sgcls's result.rel_labels is gt and not None
            # rel_labels: [num_rels, 4] (img ind, box0 ind, box1ind, rel type)
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)
            # NOTE(review): rel_labels_neg / rel_inds_neg are only created on this branch,
            # yet the training code below uses them unconditionally, so training with
            # detector-provided rel_labels (e.g. sgcls) would raise NameError.
            rel_labels_neg = self.get_neg_examples(result.rel_labels)
            rel_inds_neg = rel_labels_neg[:,:3]

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)  # [#rels, 3]: (im_ind, box1_ind, box2_ind)

        # rois: [#boxes, 5]
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)
        # detach: keep gradients from flowing back into the detector's feature map
        result.rm_obj_fmap = self.obj_feature_map(result.fmap.detach(), rois.detach())

        # BiLSTM
        result.rm_obj_dists, result.rm_obj_preds, edge_ctx = self.context(
            result.rm_obj_fmap,   # has been detached above
            # Prevent gradients from flowing back into score_fc from elsewhere
            result.rm_obj_dists.detach(),  # .detach: returns a new Variable, detached from the current graph
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all.detach() if self.mode == 'sgdet' else result.boxes_all)

        # Post Processing
        if edge_ctx is None:
            # No edge context (nl_edge == 0): embed the predicted object classes instead.
            edge_rep = self.post_emb(result.rm_obj_preds)
        else:
            edge_rep = self.post_lstm(edge_ctx)  # [#boxes, pooling_dim * 2]

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)
        subj_rep = edge_rep[:, 0]  # [#boxes, pooling_dim]
        obj_rep = edge_rep[:, 1]  # [#boxes, pooling_dim]
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]  # [#rels, pooling_dim]

        if self.use_vision:
            # union rois: fmap.detach -> union-box pooling -> roi_fmap -> vr [#rels, pooling_dim]
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])

            if self.limit_vision:
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1)
            else:
                prod_rep = prod_rep * vr
                if self.training:
                    # Same fusion for the sampled negatives, with detached subject/object reps.
                    # NOTE(review): rel_dists_neg is only defined on this path (use_vision and
                    # not limit_vision), but the training branch below always uses it; it also
                    # skips the tanh / frequency-bias steps applied to the positives.
                    vr_neg = self.visual_rep(result.fmap.detach(), rois, rel_inds_neg[:, 1:])
                    prod_rep_neg = subj_rep[rel_inds_neg[:, 1]].detach() * obj_rep[rel_inds_neg[:, 2]].detach() * vr_neg
                    rel_dists_neg = self.rel_compress(prod_rep_neg)

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)  # [#rels, num_rels]

        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.rm_obj_preds[rel_inds[:, 1]],
                result.rm_obj_preds[rel_inds[:, 2]],
            ), 1))

        if self.training:
            # Rows whose relation type is non-zero (non-background).
            judge = result.rel_labels.data[:,3] != 0
            if judge.sum() != 0:  # gt_rel exists in rel_inds
                # NOTE(review): select_rel_inds keeps the [K, 1] shape from .view(-1, 1); under
                # NumPy-style advanced indexing this adds an extra dimension to
                # rel_inds[...] / rel_dists[...]. Verify against the PyTorch version in use.
                select_rel_inds = torch.arange(rel_inds.size(0)).view(-1,1).long().cuda()[judge]
                com_rel_inds = rel_inds[select_rel_inds]
                twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                # Score of each box's predicted class (object dists detached here).
                result.obj_scores = F.softmax(result.rm_obj_dists.detach(), dim=1).view(-1)[twod_inds]

                # positive overall score: max relation prob * subject score * object score
                obj_scores0 = result.obj_scores[com_rel_inds[:,1]]
                obj_scores1 = result.obj_scores[com_rel_inds[:,2]]
                rel_rep = F.softmax(result.rel_dists[select_rel_inds], dim=1)    # result.rel_dists has grad
                _, pred_classes_argmax = rel_rep.data[:,:].max(1)  # all classes
                max_rel_score = rel_rep.gather(1, Variable(pred_classes_argmax.view(-1,1))).squeeze()  # SqueezeBackward, GatherBackward
                score_list = torch.cat((com_rel_inds[:,0].float().contiguous().view(-1,1), obj_scores0.data.view(-1,1), obj_scores1.data.view(-1,1), max_rel_score.data.view(-1,1)), 1)
                prob_score = max_rel_score * obj_scores0.detach() * obj_scores1.detach()

                # negative overall score
                obj_scores0_neg = result.obj_scores[rel_inds_neg[:,1]]
                obj_scores1_neg = result.obj_scores[rel_inds_neg[:,2]]
                rel_rep_neg = F.softmax(rel_dists_neg, dim=1)   # rel_dists_neg has grad
                _, pred_classes_argmax_neg = rel_rep_neg.data[:,:].max(1)  # all classes
                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1,1))).squeeze() # SqueezeBackward, GatherBackward
                score_list_neg = torch.cat((rel_inds_neg[:,0].float().contiguous().view(-1,1), obj_scores0_neg.data.view(-1,1), obj_scores1_neg.data.view(-1,1), max_rel_score_neg.data.view(-1,1)), 1)
                prob_score_neg = max_rel_score_neg * obj_scores0_neg.detach() * obj_scores1_neg.detach()

                # Pool positives and negatives (flag 1 / 0) and sort by score, descending.
                all_rel_inds = torch.cat((result.rel_labels.data[select_rel_inds], rel_labels_neg), 0)  # [#pos_inds+#neg_inds, 4]
                flag = torch.cat((torch.ones(prob_score.size(0),1).cuda(),torch.zeros(prob_score_neg.size(0),1).cuda()),0)
                score_list_all = torch.cat((score_list,score_list_neg), 0)
                all_prob = torch.cat((prob_score,prob_score_neg), 0)  # Variable, [#pos_inds+#neg_inds, 1]

                _, sort_prob_inds = torch.sort(all_prob.data, dim=0, descending=True)

                sorted_rel_inds = all_rel_inds[sort_prob_inds]
                sorted_flag = flag[sort_prob_inds].squeeze()  # can be used to check distribution of pos and neg
                sorted_score_list_all = score_list_all[sort_prob_inds]
                sorted_all_prob = all_prob[sort_prob_inds]  # Variable

                # positive triplet and score list (kept for inspection/debugging)
                pos_sorted_inds = sorted_rel_inds.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 1).view(-1,4)
                pos_trips = torch.cat((pos_sorted_inds[:,0].contiguous().view(-1,1), result.rm_obj_labels.data.view(-1,1)[pos_sorted_inds[:,1]], result.rm_obj_labels.data.view(-1,1)[pos_sorted_inds[:,2]], pos_sorted_inds[:,3].contiguous().view(-1,1)), 1)
                pos_score_list = sorted_score_list_all.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 1).view(-1,4)
                pos_exp = sorted_all_prob[sorted_flag == 1]  # Variable

                # negative triplet and score list (kept for inspection/debugging)
                neg_sorted_inds = sorted_rel_inds.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 0).view(-1,4)
                neg_trips = torch.cat((neg_sorted_inds[:,0].contiguous().view(-1,1), result.rm_obj_labels.data.view(-1,1)[neg_sorted_inds[:,1]], result.rm_obj_labels.data.view(-1,1)[neg_sorted_inds[:,2]], neg_sorted_inds[:,3].contiguous().view(-1,1)), 1)
                neg_score_list = sorted_score_list_all.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 0).view(-1,4)
                neg_exp = sorted_all_prob[sorted_flag == 0]  # Variable

                # Map every negative to a positive so result.pos and result.neg have equal
                # length: the lowest-scoring positives are paired with the highest-scoring
                # negatives, and the remainder is taken from the tail of pos_exp.
                int_part = neg_exp.size(0) // pos_exp.size(0)
                decimal_part = neg_exp.size(0) % pos_exp.size(0)
                int_inds = torch.arange(pos_exp.size(0))[:,None].expand_as(torch.Tensor(pos_exp.size(0), int_part)).contiguous().view(-1)
                int_part_inds = (int(pos_exp.size(0) -1) - int_inds).long().cuda() # use minimum pos to correspond maximum negative
                if decimal_part == 0:
                    expand_inds = int_part_inds
                else:
                    expand_inds = torch.cat((torch.arange(pos_exp.size(0))[(pos_exp.size(0) - decimal_part):].long().cuda(), int_part_inds), 0)

                result.pos = pos_exp[expand_inds]
                result.neg = neg_exp
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())

                return result

            else:  # no gt_rel in rel_inds: treat every labelled pair as a positive
                print("no gt_rel in rel_inds!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
                result.obj_scores = F.softmax(result.rm_obj_dists.detach(), dim=1).view(-1)[twod_inds]

                # positive overall score
                obj_scores0 = result.obj_scores[rel_inds[:,1]]
                obj_scores1 = result.obj_scores[rel_inds[:,2]]
                rel_rep = F.softmax(result.rel_dists, dim=1)    # [#rels, num_rels]
                _, pred_classes_argmax = rel_rep.data[:,:].max(1)  # all classes
                max_rel_score = rel_rep.gather(1, Variable(pred_classes_argmax.view(-1,1))).squeeze() # SqueezeBackward, GatherBackward
                prob_score = max_rel_score * obj_scores0.detach() * obj_scores1.detach()

                # negative overall score
                obj_scores0_neg = result.obj_scores[rel_inds_neg[:,1]]
                obj_scores1_neg = result.obj_scores[rel_inds_neg[:,2]]
                rel_rep_neg = F.softmax(rel_dists_neg, dim=1)
                _, pred_classes_argmax_neg = rel_rep_neg.data[:,:].max(1)  # all classes
                max_rel_score_neg = rel_rep_neg.gather(1, Variable(pred_classes_argmax_neg.view(-1,1))).squeeze() # SqueezeBackward, GatherBackward
                prob_score_neg = max_rel_score_neg * obj_scores0_neg.detach() * obj_scores1_neg.detach()

                # Pool positives and negatives (flag 1 / 0) and sort by score, descending.
                all_rel_inds = torch.cat((result.rel_labels.data, rel_labels_neg), 0)  # [#pos_inds+#neg_inds, 4]
                flag = torch.cat((torch.ones(prob_score.size(0),1).cuda(),torch.zeros(prob_score_neg.size(0),1).cuda()),0)
                all_prob = torch.cat((prob_score,prob_score_neg), 0)  # Variable, [#pos_inds+#neg_inds, 1]

                _, sort_prob_inds = torch.sort(all_prob.data, dim=0, descending=True)

                sorted_rel_inds = all_rel_inds[sort_prob_inds]
                sorted_flag = flag[sort_prob_inds].squeeze()  # can be used to check distribution of pos and neg
                sorted_all_prob = all_prob[sort_prob_inds]  # Variable

                pos_sorted_inds = sorted_rel_inds.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 1).view(-1,4)
                neg_sorted_inds = sorted_rel_inds.masked_select(sorted_flag.view(-1,1).expand(-1,4).cuda() == 0).view(-1,4)
                pos_exp = sorted_all_prob[sorted_flag == 1]  # Variable
                neg_exp = sorted_all_prob[sorted_flag == 0]  # Variable

                # Same positive/negative pairing as in the branch above.
                int_part = neg_exp.size(0) // pos_exp.size(0)
                decimal_part = neg_exp.size(0) % pos_exp.size(0)
                int_inds = torch.arange(pos_exp.data.size(0))[:,None].expand_as(torch.Tensor(pos_exp.data.size(0), int_part)).contiguous().view(-1)
                int_part_inds = (int(pos_exp.data.size(0) -1) - int_inds).long().cuda() # use minimum pos to correspond maximum negative
                if decimal_part == 0:
                    expand_inds = int_part_inds
                else:
                    expand_inds = torch.cat((torch.arange(pos_exp.size(0))[(pos_exp.size(0) - decimal_part):].long().cuda(), int_part_inds), 0)

                result.pos = pos_exp[expand_inds]
                result.neg = neg_exp
                result.anchor = Variable(torch.zeros(result.pos.size(0)).cuda())

                return result
        ###################### Testing ###########################

        # extract corresponding scores according to the box's preds
        twod_inds = arange(result.rm_obj_preds.data) * self.num_classes + result.rm_obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]   # [#boxes]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)    # [#rels, num_rels]

        # sort product of obj1 * obj2 * rel
        return filter_dets(bboxes, result.obj_scores,
                           result.rm_obj_preds, rel_inds[:, 1:],
                           rel_rep)


    def __getitem__(self, batch):
        """ Hack to do multi-GPU training"""
        batch.scatter()
        if self.num_gpus == 1:
            return self(*batch[0])

        replicas = nn.parallel.replicate(self, devices=list(range(self.num_gpus)))
        outputs = nn.parallel.parallel_apply(replicas, [batch[i] for i in range(self.num_gpus)])

        if self.training:
            return gather_res(outputs, 0, dim=0)
        return outputs
Пример #6
0
    def __init__(self,
                 classes,
                 rel_classes,
                 embed_dim,
                 obj_dim,
                 inputs_dim,
                 hidden_dim,
                 pooling_dim,
                 recurrent_dropout_probability=0.2,
                 use_highway=True,
                 use_input_projection_bias=True,
                 use_vision=True,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True,
                 sl_pretrain=False,
                 num_iter=-1):
        """
        Build the object-decoding RNN cell together with the relation-prediction head.

        :param classes: object class names (word vectors are looked up for them)
        :param rel_classes: relation class names (word vectors are looked up for them)
        :param embed_dim: dimension of the obj_embed2 class embeddings used by the relation head
        :param obj_dim: object feature dimension (part of post_lstm's input size)
        :param inputs_dim: dimension of the decoder inputs (stored as ``inputs_dim``)
        :param hidden_dim: hidden size of the decoder cell (stored as ``hidden_size``)
        :param pooling_dim: post_lstm outputs 2 * pooling_dim; rel_compress consumes pooling_dim
        :param recurrent_dropout_probability: stored; presumably used by the recurrent step (not shown)
        :param use_highway: size the gate projections for highway connections (6/5 chunks instead of 4/4)
        :param use_input_projection_bias: whether input_linearity has a bias
        :param use_vision: stored relation-head switch
        :param use_bias: also build a FrequencyBias and a VG-count based ``prob_matrix``
        :param use_tanh: stored relation-head switch
        :param limit_vision: stored relation-head switch
        :param sl_pretrain: stored as an attribute
        :param num_iter: stored as an attribute
        """
        super(DecoderRNN, self).__init__()

        # Word-vector initialised embeddings for objects (with a leading 'start' token)
        # and for relations.
        self.rel_embedding_dim = 100
        self.classes = classes
        self.rel_classes = rel_classes
        obj_vectors = obj_edge_vectors(['start'] + self.classes, wv_dim=100)
        self.obj_embed = nn.Embedding(len(self.classes), embed_dim)
        # NOTE(review): overwriting .data replaces the declared len(classes) x embed_dim
        # weight with obj_vectors (presumably (len(classes) + 1) x 100) — confirm intent.
        self.obj_embed.weight.data = obj_vectors

        rel_vectors = obj_edge_vectors(self.rel_classes, wv_dim=self.rel_embedding_dim)
        self.rel_embed = nn.Embedding(len(self.rel_classes), self.rel_embedding_dim)
        self.rel_embed.weight.data = rel_vectors

        # Sizes.
        self.embed_dim = embed_dim
        self.obj_dim = obj_dim
        self.hidden_size = hidden_dim
        self.inputs_dim = inputs_dim
        self.pooling_dim = pooling_dim
        self.nms_thresh = 0.3

        # Relation-head switches and bookkeeping.
        self.use_vision = use_vision
        self.use_bias = use_bias
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.sl_pretrain = sl_pretrain
        self.num_iter = num_iter

        self.recurrent_dropout_probability = recurrent_dropout_probability
        self.use_highway = use_highway

        # All gate projections are done with a single matmul each; highway connections
        # need two extra input chunks and one extra state chunk.
        n_input_chunks, n_state_chunks = (6, 5) if use_highway else (4, 4)
        self.input_linearity = torch.nn.Linear(self.input_size,
                                               n_input_chunks * self.hidden_size,
                                               bias=use_input_projection_bias)
        self.state_linearity = torch.nn.Linear(self.hidden_size,
                                               n_state_chunks * self.hidden_size,
                                               bias=True)

        self.out = nn.Linear(self.hidden_size, len(self.classes))
        self.reset_parameters()

        # ---- Relation prediction head ----
        class_vectors = obj_edge_vectors(self.classes, wv_dim=embed_dim)
        self.obj_embed2 = nn.Embedding(self.num_classes, embed_dim)
        self.obj_embed2.weight.data = class_vectors.clone()

        post_in_dim = self.obj_dim + 2 * self.embed_dim + 128
        self.post_lstm = nn.Linear(post_in_dim, self.pooling_dim * 2)
        # Scaled normal init, 10 * sqrt(1 / hidden_size), carried over from the LSTM-fed
        # variant. The actual fan-in here is post_in_dim, so this may need revisiting.
        self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_size))
        self.post_lstm.bias.data.zero_()

        self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)

        if not self.use_bias:
            return

        self.freq_bias = FrequencyBias()

        # Simple relation model: predicate frequencies counted on the VG training split.
        from dataloaders.visual_genome import VG
        from lib.get_dataset_counts import get_counts, box_filter
        train_split = VG.splits(num_val_im=5000,
                                filter_non_overlap=True,
                                filter_duplicate_rels=True,
                                use_proposals=False)[0]
        fg_matrix, bg_matrix = get_counts(train_data=train_split, must_overlap=True)

        probs = fg_matrix.astype(np.float32)
        probs[:, :, 0] = bg_matrix
        # Add-one smoothing on the background column, then normalise over the last axis.
        probs[:, :, 0] += 1
        probs /= np.sum(probs, 2)[:, :, None]
        probs[:, :, 0] = 0  # Zero out BG
        self.prob_matrix = probs
class RelModel(RelModelBase):
    """
    Depth-Fusion relation detection model.

    Every enabled feature branch (visual, location, class, depth) turns a
    (subject, object) pair into a fixed-size vector. The branch vectors are
    concatenated, passed through a fusion hidden layer, optionally modulated
    by union-of-bounding-boxes (UoBB) features, and projected to relation
    logits (plus an optional frequency bias).
    """

    # -- Different components' FC layer size
    FC_SIZE_VISUAL = 512
    FC_SIZE_CLASS = 64
    FC_SIZE_LOC = 20
    FC_SIZE_DEPTH = 4096
    # -- Width of the pairwise location feature built by get_loc_features:
    #    2 center-offset + 2 log-size-ratio values for the subject, and the
    #    same 4 values for the object.
    LOC_INPUT_SIZE = 8

    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=False,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=4096,
                 use_resnet=False,
                 thresh=0.01,
                 use_proposals=False,
                 use_bias=True,
                 limit_vision=True,
                 depth_model=None,
                 pretrained_depth=False,
                 active_features=None,
                 frozen_features=None,
                 use_embed=False,
                 **kwargs):
        """
        :param classes: object classes
        :param rel_classes: relationship classes. None if were not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUS 2 use
        :param use_vision: enable the contribution of union of bounding boxes
        :param require_overlap_det: whether two objects must intersect
        :param embed_dim: word2vec embeddings dimension
        :param hidden_dim: dimension of the fusion hidden layer
        :param use_resnet: use resnet as faster-rcnn's backbone (not supported)
        :param thresh: faster-rcnn related threshold (Threshold for calling it a good box)
        :param use_proposals: whether to use region proposal candidates
        :param use_bias: enable frequency bias
        :param limit_vision: only modulate the first half of the relation
                             embedding with UoBB features
        :param depth_model: provided architecture for depth feature extraction
                            (must be a key of DEPTH_MODELS)
        :param pretrained_depth: whether the depth feature extractor should be initialized with ImageNet weights
        :param active_features: what set of features should be enabled (e.g. 'vdl' : visual, depth, and location features)
        :param frozen_features: what set of features should be frozen (e.g. 'd' : depth)
        :param use_embed: use word2vec embeddings for the class features
        :param kwargs: accepted for interface compatibility; ignored
        :raises ValueError: if use_resnet is True
        """
        RelModelBase.__init__(self, classes, rel_classes, mode, num_gpus,
                              require_overlap_det, active_features,
                              frozen_features)
        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.use_vision = use_vision
        self.use_bias = use_bias
        self.limit_vision = limit_vision

        # -- Store depth related parameters
        assert depth_model in DEPTH_MODELS
        self.depth_model = depth_model
        self.pretrained_depth = pretrained_depth
        self.depth_pooling_dim = DEPTH_DIMS[self.depth_model]
        self.use_embed = use_embed
        # Placeholder so self.detector always exists; replaced below when the
        # visual branch is enabled.
        self.detector = nn.Module()
        features_size = 0

        # -- Check whether ResNet is selected as faster-rcnn's backbone
        if use_resnet:
            raise ValueError(
                "The current model does not support ResNet as the Faster-RCNN's backbone."
            )

        # *** DIFFERENT COMPONENTS OF THE PROPOSED ARCHITECTURE ***
        # For RGB images there are class probability distribution features,
        # visual features and location features. When depth images are also
        # used, depth features are added. Each enabled branch contributes its
        # FC size to `features_size`, the input width of the fusion layer.

        # -- Visual features
        if self.has_visual:
            # -- Define faster R-CNN network and its related feature extractors
            self.detector = ObjectDetector(
                classes=classes,
                mode=('proposals' if use_proposals else 'refinerels')
                if mode == 'sgdet' else 'gtbox',
                use_resnet=use_resnet,
                thresh=thresh,
                max_per_img=64,
            )
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

            # -- Define union features
            if self.use_vision:
                # -- UoBB pooling module
                self.union_boxes = UnionBoxesAndFeats(
                    pooling_size=self.pooling_size,
                    stride=16,
                    dim=1024 if use_resnet else 512)

                # -- UoBB feature extractor; its output width is hidden_dim
                #    so it can be multiplied with the relation embeddings.
                roi_fmap = [
                    Flattener(),
                    load_vgg(use_dropout=False,
                             use_relu=False,
                             use_linear=self.hidden_dim == 4096,
                             pretrained=False).classifier,
                ]
                if self.hidden_dim != 4096:
                    roi_fmap.append(nn.Linear(4096, self.hidden_dim))
                self.roi_fmap = nn.Sequential(*roi_fmap)

            # -- Define visual features hidden layer (input: subj ++ obj features)
            self.visual_hlayer = nn.Sequential(*[
                xavier_init(nn.Linear(self.obj_dim * 2, self.FC_SIZE_VISUAL)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.8)
            ])
            self.visual_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_VISUAL

        # -- Location features
        if self.has_loc:
            # -- Define location features hidden layer
            self.location_hlayer = nn.Sequential(*[
                xavier_init(nn.Linear(self.LOC_INPUT_SIZE, self.FC_SIZE_LOC)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1)
            ])
            self.location_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_LOC

        # -- Class features
        if self.has_class:
            if self.use_embed:
                # -- Define class embeddings initialised from word vectors
                embed_vecs = obj_edge_vectors(self.classes,
                                              wv_dim=self.embed_dim)
                self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim)
                self.obj_embed.weight.data = embed_vecs.clone()

            classme_input_dim = self.embed_dim if self.use_embed else self.num_classes
            # -- Define Class features hidden layer (input: subj ++ obj class vectors)
            self.classme_hlayer = nn.Sequential(*[
                xavier_init(
                    nn.Linear(classme_input_dim * 2, self.FC_SIZE_CLASS)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1)
            ])
            self.classme_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_CLASS

        # -- Depth features
        if self.has_depth:
            # -- Initialize depth backbone
            self.depth_backbone = DepthCNN(depth_model=self.depth_model,
                                           pretrained=self.pretrained_depth)

            # -- Create a relation head which is used to carry on the feature extraction
            # from RoIs of depth features
            self.depth_rel_head = self.depth_backbone.get_classifier()

            # -- Define depth features hidden layer (input: subj ++ obj depth features)
            self.depth_rel_hlayer = nn.Sequential(*[
                xavier_init(
                    nn.Linear(self.depth_pooling_dim * 2, self.FC_SIZE_DEPTH)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.6),
            ])
            self.depth_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_DEPTH

        # -- Initialize frequency bias if needed
        if self.use_bias:
            self.freq_bias = FrequencyBias()

        # -- *** Fusion layer *** --
        # -- A hidden layer for concatenated features (fusion features)
        self.fusion_hlayer = nn.Sequential(*[
            xavier_init(nn.Linear(features_size, self.hidden_dim)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1)
        ])

        # -- Final FC layer which predicts the relations
        self.rel_out = xavier_init(
            nn.Linear(self.hidden_dim, self.num_rels, bias=True))

        # -- Freeze the user specified features. Each freeze is applied only
        #    when the matching branch is enabled, because the modules being
        #    frozen are only created for enabled branches above (otherwise
        #    this would raise AttributeError).
        if self.frz_visual and self.has_visual:
            self.freeze_module(self.detector)
            self.freeze_module(self.roi_fmap_obj)
            self.freeze_module(self.visual_hlayer)
            if self.use_vision:
                self.freeze_module(self.roi_fmap)
                self.freeze_module(self.union_boxes.conv)

        if self.frz_class and self.has_class:
            self.freeze_module(self.classme_hlayer)

        if self.frz_loc and self.has_loc:
            self.freeze_module(self.location_hlayer)

        if self.frz_depth and self.has_depth:
            self.freeze_module(self.depth_backbone)
            self.freeze_module(self.depth_rel_head)
            self.freeze_module(self.depth_rel_hlayer)

    def _pool_rois(self, features, rois):
        """
        Pool every RoI to a (pooling_size x pooling_size) grid with RoIAlign.
        :param features: feature map at stride 16 (spatial_scale = 1/16)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: pooled features, one entry per RoI
        """
        return RoIAlign((self.pooling_size, self.pooling_size),
                        spatial_scale=1 / 16,
                        sampling_ratio=-1)(features, rois)

    def get_roi_features(self, features, rois):
        """
        Gets ROI features
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: [num_rois, #dim] array
        """
        feature_pool = self._pool_rois(features, rois)
        return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1))

    def get_union_features(self, features, rois, pair_inds):
        """
        Gets UoBB features
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4]
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :param pair_inds: [num_pairs, 2] (subject, object) RoI indices
        :return: UoBB features, one row per pair
        """
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return self.roi_fmap(uboxes)

    def get_roi_features_depth(self, features, rois):
        """
        Gets ROI features (depth)
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: [num_rois, #dim] array
        """
        feature_pool = self._pool_rois(features, rois)

        # -- Flatten the pooled maps unless the depth head is a ResNet/SqueezeNet
        #    head (those heads consume the spatial maps directly).
        if self.depth_model not in ('resnet18', 'resnet50', 'sqznet'):
            feature_pool = feature_pool.view(rois.size(0), -1)

        return self.depth_rel_head(feature_pool)

    @staticmethod
    def get_loc_features(boxes, subj_inds, obj_inds):
        """
        Calculate the scale-invariant location feature
        :param boxes: ground-truth/detected boxes
        :param subj_inds: subject indices
        :param obj_inds: object indices
        :return: location_feature of width 8: for each of subject and object,
                 the center offset normalised by the other box's size (2) and
                 the log size ratio (2)
        """
        boxes_centered = center_size(boxes.data)

        # -- Determine box's center and size (subj's box)
        center_subj = boxes_centered[subj_inds][:, 0:2]
        size_subj = boxes_centered[subj_inds][:, 2:4]

        # -- Determine box's center and size (obj's box)
        center_obj = boxes_centered[obj_inds][:, 0:2]
        size_obj = boxes_centered[obj_inds][:, 2:4]

        # -- Calculate the scale-invariant location features of the subject
        t_coord_subj = (center_subj - center_obj) / size_obj
        t_size_subj = torch.log(size_subj / size_obj)

        # -- Calculate the scale-invariant location features of the object
        t_coord_obj = (center_obj - center_subj) / size_subj
        t_size_obj = torch.log(size_obj / size_subj)

        # -- Put everything together
        location_feature = Variable(
            torch.cat((t_coord_subj, t_size_subj, t_coord_obj, t_size_obj), 1))
        return location_feature

    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False,
                depth_imgs=None):
        """
        Forward pass for relation detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: a numpy array of (h, w, scale) for each image.
        :param image_offset: offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param gt_rels: [] gt relations
        :param proposals: region proposals retrieved from file
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :param return_fmap: accepted for interface compatibility; the detector
                            is always asked for its feature maps because the
                            visual branch needs them
        :param depth_imgs: depth images [batch_size, 1, IM_SIZE, IM_SIZE]
        :return: in training mode, the detector `result` object with
                 `rel_dists` (relation logits) filled in; otherwise the output
                 of filter_dets(boxes, obj_scores, obj_preds, pairs, rel_probs)
        """

        if self.has_visual:
            # -- Feed forward the rgb images to Faster-RCNN
            result = self.detector(x,
                                   im_sizes,
                                   image_offset,
                                   gt_boxes,
                                   gt_classes,
                                   gt_rels,
                                   proposals,
                                   train_anchor_inds,
                                   return_fmap=True)
        else:
            # -- Get prior `result` object (instead of calling faster-rcnn's detector)
            result = self.get_prior_results(image_offset, gt_boxes, gt_classes,
                                            gt_rels)

        # -- Get RoI and relations
        rois, rel_inds = self.get_rois_and_rels(result, image_offset, gt_boxes,
                                                gt_classes, gt_rels)
        boxes = result.rm_box_priors

        # -- Determine subject and object indices
        subj_inds = rel_inds[:, 1]
        obj_inds = rel_inds[:, 2]

        # -- Prepare object predictions vector (PredCLS)
        # NOTE(review): predictions are replaced with ground-truth labels
        # regardless of self.mode, which matches predcls; confirm this is
        # intended for the other modes.
        result.obj_preds = result.rm_obj_labels
        # Replace with one-hot distribution of ground truth labels
        result.rm_obj_dists = F.one_hot(result.rm_obj_labels.data,
                                        self.num_classes).float()
        obj_cls = result.rm_obj_dists
        # Turn the one-hot vector into logits of +1000 / -1000 so that the
        # softmax at test time is (numerically) one-hot as well.
        result.rm_obj_dists = result.rm_obj_dists * 1000 + (
            1 - result.rm_obj_dists) * (-1000)

        rel_features = []
        # -- Extract RGB features
        if self.has_visual:
            # Feed the extracted features from first conv layers to the last 'classifier' layers (VGG).
            # The feature map is detached, so no gradient flows back into self.detector from here.
            result.obj_fmap = self.get_roi_features(result.fmap.detach(), rois)

            # -- Create a pairwise relation vector out of visual features
            rel_visual = torch.cat(
                (result.obj_fmap[subj_inds], result.obj_fmap[obj_inds]), 1)
            rel_visual_fc = self.visual_hlayer(rel_visual)
            rel_visual_scale = self.visual_scale(rel_visual_fc)
            rel_features.append(rel_visual_scale)

        # -- Extract Location features
        if self.has_loc:
            # -- Create a pairwise relation vector out of location features
            rel_location = self.get_loc_features(boxes, subj_inds, obj_inds)
            rel_location_fc = self.location_hlayer(rel_location)
            rel_location_scale = self.location_scale(rel_location_fc)
            rel_features.append(rel_location_scale)

        # -- Extract Class features
        if self.has_class:
            if self.use_embed:
                # One-hot x embedding matrix == embedding lookup of each label
                obj_cls = obj_cls @ self.obj_embed.weight
            # -- Create a pairwise relation vector out of class features
            rel_classme = torch.cat((obj_cls[subj_inds], obj_cls[obj_inds]), 1)
            rel_classme_fc = self.classme_hlayer(rel_classme)
            rel_classme_scale = self.classme_scale(rel_classme_fc)
            rel_features.append(rel_classme_scale)

        # -- Extract Depth features
        if self.has_depth:
            # -- Extract features from depth backbone
            depth_features = self.depth_backbone(depth_imgs)
            depth_rois_features = self.get_roi_features_depth(
                depth_features, rois)

            # -- Create a pairwise relation vector out of depth features
            rel_depth = torch.cat((depth_rois_features[subj_inds],
                                   depth_rois_features[obj_inds]), 1)
            rel_depth_fc = self.depth_rel_hlayer(rel_depth)
            rel_depth_scale = self.depth_scale(rel_depth_fc)
            rel_features.append(rel_depth_scale)

        # -- Create concatenated feature vector
        rel_fusion = torch.cat(rel_features, 1)

        # -- Extract relation embeddings (penultimate layer)
        rel_embeddings = self.fusion_hlayer(rel_fusion)

        # -- Mix relation embeddings with UoBB features
        if self.has_visual and self.use_vision:
            uobb_features = self.get_union_features(result.fmap.detach(), rois,
                                                    rel_inds[:, 1:])
            if self.limit_vision:
                # Only the first half of the embedding is modulated (exact split TBD)
                uobb_limit = self.hidden_dim // 2
                rel_embeddings = torch.cat((rel_embeddings[:, :uobb_limit] *
                                            uobb_features[:, :uobb_limit],
                                            rel_embeddings[:, uobb_limit:]), 1)
            else:
                rel_embeddings = rel_embeddings * uobb_features

        # -- Predict relation logits
        result.rel_dists = self.rel_out(rel_embeddings)

        # -- Frequency bias, indexed by the (subject, object) label pair
        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        # --- *** END OF ARCHITECTURE *** ---#

        # Flat indices of each object's predicted class in the
        # [num_objs, num_classes] score matrix.
        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        # Filtering: Subject_Score * Pred_score * Obj_score, sorted and ranked
        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)
Пример #8
0
class RelModel(nn.Module):
    """
    RELATIONSHIPS

    Relation detection model: an ObjectDetector produces boxes and feature
    maps, a LinearizedContext refines object predictions and edge context, and
    relation logits come from the elementwise product of subject and object
    edge representations (optionally modulated by union-box visual features
    and shifted by a frequency bias).
    """
    def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True,
                 embed_dim=200, hidden_dim=256, pooling_dim=2048,
                 nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01,
                 use_proposals=False, pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True,
                 limit_vision=True):

        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if were not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUS 2 use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect (only applied in sgdet mode)
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size (input width of post_lstm)
        :param pooling_dim: width of each subject/object edge representation
        :param nl_obj: passed to LinearizedContext
        :param nl_edge: passed to LinearizedContext; when 0, a post_emb class
                        embedding is also created and used in forward when the
                        context returns no edge context
        :param use_resnet: use ResNet instead of VGG as the backbone
        :param order: passed to LinearizedContext
        :param thresh: detector threshold
        :param use_proposals: use external proposals in sgdet mode
        :param pass_in_obj_feats_to_decoder: passed to LinearizedContext
        :param pass_in_obj_feats_to_edge: passed to LinearizedContext
        :param rec_dropout: dropout rate passed to LinearizedContext
        :param use_bias: add the frequency bias to relation logits
        :param use_tanh: apply tanh to the pairwise representation
        :param limit_vision: only modulate the first 2048 dims with union features
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(self.classes, self.rel_classes, mode=self.mode,
                                         embed_dim=self.embed_dim, hidden_dim=self.hidden_dim,
                                         obj_dim=self.obj_dim,
                                         nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout,
                                         order=order,
                                         pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
                                         pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)
        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16,
                                              dim=1024 if use_resnet else 512)
        if use_resnet:
            # NOTE(review): roi_fmap_obj is not created on this path, so
            # obj_feature_map (used by forward) will fail with a ResNet backbone.
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        # Projects edge context to concatenated [subject_rep, object_rep].
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)

        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half contribution comes from LSTM, half from embedding.)
        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        if nl_edge == 0:
            self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim*2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()

    @property
    def num_classes(self):
        """Number of object classes."""
        return len(self.classes)

    @property
    def num_rels(self):
        """Number of relationship classes."""
        return len(self.rel_classes)

    def visual_rep(self, features, rois, pair_inds):
        """
        Compute union-box (UoBB) visual features for each object pair.
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4]
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :param pair_inds: [num_pairs, 2] (subject, object) RoI indices
        :return: self.roi_fmap applied to the union-box features, one row per
                 pair (presumably pooling_dim wide on the VGG path)
        """
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return self.roi_fmap(uboxes)

    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        """
        Get the relationship candidates.
        :param rel_labels: sampled relation labels (used only in training)
        :param im_inds: image index of every box
        :param box_priors: boxes, used for the overlap requirement
        :return: [num_cands, 3] tensor of (img_ind, subj_ind, obj_ind)
        """
        if self.training:
            # First three columns of the sampled labels, laid out like the
            # test-time candidates built below.
            return rel_labels[:, :3].data.clone()

        # Every ordered pair of distinct boxes from the same image.
        rel_cands = im_inds.data[:, None] == im_inds.data[None]
        rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

        # Require overlap for detection
        if self.require_overlap:
            rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                   box_priors.data) > 0)

        rel_cands = rel_cands.nonzero()
        # numel() catches both the old 0-dim empty result and the modern
        # (0, 2)-shaped empty result of nonzero(); fall back to a single dummy pair.
        if rel_cands.numel() == 0:
            rel_cands = im_inds.data.new(1, 2).fill_(0)

        return torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)

    def obj_feature_map(self, features, rois):
        """
        Gets the ROI features
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: [num_rois, #dim] array
        """
        feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size, spatial_scale=1 / 16)(
            features, rois)
        return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1))

    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param gt_rels: GT relations
        :param proposals: region proposals
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :param return_fmap: accepted for interface compatibility; feature maps
                            are always requested from the detector
        :return: in training mode, the detector `result` with `rel_dists`
                 (relation logits) filled in; otherwise the output of filter_dets
        :raises ValueError: if the detector returns an empty result
        """
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)
        if result.is_none():
            raise ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, edge_ctx = self.context(
            result.obj_fmap,
            result.rm_obj_dists.detach(),
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all)

        if edge_ctx is None:
            edge_rep = self.post_emb(result.obj_preds)
        else:
            edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)

        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]

        if self.use_vision:
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
            if self.limit_vision:
                # exact value TBD. NOTE: the split is hard-coded at 2048; with
                # pooling_dim == 2048 this is equivalent to prod_rep * vr.
                prod_rep = torch.cat((prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1)
            else:
                prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = torch.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.obj_preds[rel_inds[:, 1]],
                result.obj_preds[rel_inds[:, 2]],
            ), 1))

        if self.training:
            return result

        # Flat indices of each object's predicted class in the score matrix.
        twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        return filter_dets(bboxes, result.obj_scores,
                           result.obj_preds, rel_inds[:, 1:], rel_rep)

    def __getitem__(self, batch):
        """ Hack to do multi-GPU training"""
        batch.scatter()
        if self.num_gpus == 1:
            return self(*batch[0])

        replicas = nn.parallel.replicate(self, devices=list(range(self.num_gpus)))
        outputs = nn.parallel.parallel_apply(replicas, [batch[i] for i in range(self.num_gpus)])

        if self.training:
            return gather_res(outputs, 0, dim=0)
        return outputs
Пример #9
0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 gnn=True,
                 reachability=False,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        Build a relationship model with a global-image branch, optional graph
        networks, and a centroid-based predicate classifier.

        :param classes: Object classes
        :param rel_classes: Relationship (predicate) classes. ``self.num_rels``
                            is used below, so presumably this must not be None
                            here -- TODO confirm against the num_rels property.
        :param mode: (sgcls, predcls, or sgdet); must be a member of MODES
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision in the final product (only stored here)
        :param require_overlap_det: Whether two objects must intersect; only
                                    takes effect when mode == 'sgdet'
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size (input width of ``post_lstm``)
        :param pooling_dim: output width of ``roi_fmap`` (VGG branch)
        :param nl_obj: forwarded to LinearizedContext
        :param nl_edge: forwarded to LinearizedContext; when 0 a ``post_emb``
                        embedding table is also created
        :param use_resnet: use the ResNet layer-4 head instead of VGG; also
                           selects obj_dim (2048 vs 4096) and the union-box dim
        :param order: forwarded to LinearizedContext
        :param thresh: forwarded to ObjectDetector
        :param use_proposals: in sgdet mode, run the detector in 'proposals'
                              mode instead of 'refinerels'
        :param pass_in_obj_feats_to_decoder: forwarded to LinearizedContext
        :param gnn: build the node/edge GraphNetwork modules
        :param reachability: only stored as ``self.reachability`` here
        :param pass_in_obj_feats_to_edge: forwarded to LinearizedContext
        :param rec_dropout: forwarded to LinearizedContext as dropout_rate
        :param use_bias: create a FrequencyBias module
        :param use_tanh: only stored here
        :param limit_vision: only stored here
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        self.reachability = reachability
        self.gnn = gnn
        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        # Overlap filtering is only applied when detecting boxes (sgdet).
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        self.global_embedding = EmbeddingImagenet(4096)
        # NOTE(review): 151 is hard-coded; presumably the number of object
        # classes incl. background -- confirm it matches len(classes).
        self.global_logist = nn.Linear(4096, 151,
                                       bias=True)  # CosineLinear(4096,150)#
        self.global_logist.weight = torch.nn.init.xavier_normal(
            self.global_logist.weight, gain=1.0)

        # Feature width is pooling_dim + 256 (256 = edge_coordinate_embedding output).
        self.disc_center = DiscCentroidsLoss(self.num_rels,
                                             self.pooling_dim + 256)
        self.meta_classify = MetaEmbedding_Classifier(
            feat_dim=self.pooling_dim + 256, num_classes=self.num_rels)

        # self.global_rel_logist = nn.Linear(4096, 50 , bias=True)
        # self.global_rel_logist.weight = torch.nn.init.xavier_normal(self.global_rel_logist.weight, gain=1.0)

        # self.global_logist = CosineLinear(4096,150)
        # Scalar gates for mixing the global image feature into subject/object reps.
        self.global_sub_additive = nn.Linear(4096, 1, bias=True)
        self.global_obj_additive = nn.Linear(4096, 1, bias=True)

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            # Project the 4096-d VGG output down when a smaller pooling_dim is requested.
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            # NOTE(review): roi_fmap_obj is only created on the VGG path.
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)

        # Embeds 5 scalars (presumably normalized box coordinates + area --
        # TODO confirm with forward) into a 256-d vector.
        self.edge_coordinate_embedding = nn.Sequential(*[
            nn.BatchNorm1d(5, momentum=BATCHNORM_MOMENTUM / 10.0),
            nn.Linear(5, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ])
        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half contribution comes from LSTM, half from embedding.

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        if nl_edge == 0:
            self.post_emb = nn.Embedding(self.num_classes,
                                         self.pooling_dim * 2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        # NOTE(review): 4096 + 256 and 51 are hard-coded; presumably
        # pooling_dim + 256 and num_rels -- confirm before changing defaults.
        self.rel_compress = nn.Linear(4096 + 256, 51, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)

        self.node_transform = nn.Linear(4096, 256, bias=True)
        self.edge_transform = nn.Linear(4096, 256, bias=True)
        # self.rel_compress = CosineLinear(self.pooling_dim+256, self.num_rels)
        # self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()
        if self.gnn:
            self.graph_network_node = GraphNetwork(4096)
            self.graph_network_edge = GraphNetwork()
            # NOTE(review): nn.Module.__init__ sets self.training = True, so at
            # construction time this always takes the train() branch; the
            # eval() branch is unreachable here. Mode is normally switched later
            # via model.train()/model.eval(), which recurses into submodules.
            if self.training:
                self.graph_network_node.train()
                self.graph_network_edge.train()
            else:
                self.graph_network_node.eval()
                self.graph_network_edge.eval()
        self.edge_sim_network = nn.Linear(4096, 1, bias=True)
        self.metric_net = MetricLearning()
Пример #10
0
class RelModel(nn.Module):
    """
    RELATIONSHIPS

    Scene-graph relationship model.

    Pipeline, as implemented in ``forward``:
      1. An ``ObjectDetector`` produces boxes and a feature map.
      2. ``LinearizedContext`` refines object distributions and yields
         per-object representations.
      3. A global image embedding is gated into subject and object
         representations.
      4. Pairwise products are fused with union-box visual features and
         classified by ``MetaEmbedding_Classifier``, which is also given
         ``self.centroids``.
    """

    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 model_path='',
                 reachability=False,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 init_center=False,
                 limit_vision=True):
        """
        :param classes: Object class names; ``num_classes == len(classes)``.
        :param rel_classes: Predicate class names; ``num_rels == len(rel_classes)``.
        :param mode: One of MODES (sgcls, predcls, or sgdet).
        :param num_gpus: Number of GPUs used by ``__getitem__``.
        :param use_vision: Stored only; not read by this class.
        :param require_overlap_det: In sgdet eval, keep only overlapping box pairs.
        :param embed_dim: Forwarded to LinearizedContext.
        :param hidden_dim: Forwarded to LinearizedContext; input width of ``post_lstm``.
        :param pooling_dim: Width of relation / global features (VGG path).
        :param nl_obj: Forwarded to LinearizedContext.
        :param nl_edge: Forwarded to LinearizedContext.
        :param use_resnet: Use ResNet layer-4 head instead of VGG.
        :param order: Forwarded to LinearizedContext.
        :param thresh: Forwarded to ObjectDetector.
        :param use_proposals: In sgdet, run the detector in 'proposals' mode.
        :param pass_in_obj_feats_to_decoder: Forwarded to LinearizedContext.
        :param model_path: Stored as ``self.model_path``.
        :param reachability: Accepted for signature compatibility; unused.
        :param pass_in_obj_feats_to_edge: Forwarded to LinearizedContext.
        :param rec_dropout: Forwarded to LinearizedContext as dropout_rate.
        :param use_bias: Add FrequencyBias logits to relation scores.
        :param use_tanh: Stored only; ``forward`` always applies tanh.
        :param init_center: Stored as ``self.init_center``.
        :param limit_vision: Stored only; not read by this class.
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        self.init_center = init_center
        self.pooling_size = 7
        self.model_path = model_path
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim
        self.centroids = None
        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        # Overlap filtering only applies when boxes are detected (sgdet).
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        self.global_embedding = EmbeddingImagenet(4096)
        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)
        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            # Project the 4096-d VGG output when a smaller pooling_dim is used.
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)
        self.disc_center = DiscCentroidsLoss(self.num_rels, self.pooling_dim)
        self.meta_classify = MetaEmbedding_Classifier(
            feat_dim=self.pooling_dim, num_classes=self.num_rels)
        self.disc_center_g = DiscCentroidsLoss(self.num_classes,
                                               self.pooling_dim)
        self.meta_classify_g = MetaEmbedding_Classifier(
            feat_dim=self.pooling_dim, num_classes=self.num_classes)
        # Scalar gates deciding how much global context to add per object.
        self.global_sub_additive = nn.Linear(4096, 1, bias=True)
        self.global_obj_additive = nn.Linear(4096, 1, bias=True)
        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half contribution comes from LSTM, half from embedding.

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        self.global_logist = nn.Linear(self.pooling_dim,
                                       self.num_classes,
                                       bias=True)  # CosineLinear(4096,150)#
        self.global_logist.weight = torch.nn.init.xavier_normal(
            self.global_logist.weight, gain=1.0)

        self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2)
        self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        # NOTE: rel_compress is built (and initialised) but forward() classifies
        # relations with meta_classify instead.
        self.rel_compress = nn.Linear(self.pooling_dim,
                                      self.num_rels,
                                      bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()
        # Per-class accumulators used by center_calulate().
        self.class_num = torch.zeros(len(self.classes))
        self.centroids = torch.zeros(len(self.classes),
                                     self.pooling_dim).cuda()

    @property
    def num_classes(self):
        """Number of object classes."""
        return len(self.classes)

    @property
    def num_rels(self):
        """Number of predicate classes."""
        return len(self.rel_classes)

    def visual_rep(self, features, rois, pair_inds):
        """
        Union-box visual features for each (subject, object) pair.

        :param features: backbone feature map (detached by the caller)
        :param rois: [num_rois, 5] rows of [img_ind, x0, y0, x1, y1]
        :param pair_inds: [num_pairs, 2] indices into ``rois``
        :return: ``roi_fmap`` applied to the pooled union-box features
        """
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return self.roi_fmap(uboxes)

    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        """
        Relationship candidates as rows of [img_ind, subj_ind, obj_ind].

        Training: the first three columns of ``rel_labels``.
        Eval: every ordered pair of distinct boxes from the same image,
        optionally restricted to overlapping boxes (``self.require_overlap``).
        """
        if self.training:
            rel_inds = rel_labels[:, :3].data.clone()
        else:
            rel_cands = im_inds.data[:, None] == im_inds.data[None]
            # A box is never related to itself.
            rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

            # Require overlap for detection
            if self.require_overlap:
                rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                       box_priors.data) > 0)

            rel_cands = rel_cands.nonzero()
            # Older PyTorch returns a 0-dim tensor when nothing is nonzero;
            # fall back to a single dummy pair so downstream indexing works.
            if rel_cands.dim() == 0:
                rel_cands = im_inds.data.new(1, 2).fill_(0)

            rel_inds = torch.cat(
                (im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
        return rel_inds

    def _to_one_hot(self, y, n_dims, dtype=torch.cuda.FloatTensor):
        """One-hot encode ``y`` along a new trailing axis of size ``n_dims``."""
        scatter_dim = len(y.size())
        y_tensor = y.type(torch.cuda.LongTensor).view(*y.size(), -1)
        zeros = torch.zeros(*y.size(), n_dims).type(dtype)

        return zeros.scatter(scatter_dim, y_tensor, 1)

    def obj_feature_map(self, features, rois):
        """
        RoIAlign-pool each roi (stride 16) and run it through ``roi_fmap_obj``.

        :param rois: [num_rois, 5] rows of [img_ind, x0, y0, x1, y1]
        :return: [num_rois, dim] object features
        """
        feature_pool = RoIAlignFunction(self.pooling_size,
                                        self.pooling_size,
                                        spatial_scale=1 / 16)(features, rois)
        return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1))

    def center_calulate(self, feature, labels):
        """Accumulate ``feature[idx]`` into ``self.centroids[label]`` and count it."""
        for idx, i in enumerate(labels):
            self.centroids[i] += feature[idx]
            self.class_num[i] += 1

    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Run detection, context refinement and relation classification.

        :param x: input images
        :param im_sizes: image sizes, forwarded to the detector
        :param image_offset: offset subtracted from detector image indices
        :param gt_boxes / gt_classes / gt_rels: ground truth (used for
               training-time relation assignment in sgdet)
        :param proposals / train_anchor_inds: forwarded to the detector
        :param return_fmap: unused; the detector is always asked for the fmap
        :return: the detector result object (training), or ``filter_dets``
                 output (eval)
        :raises ValueError: if the detector returns an empty result
        """
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        if result.is_none():
            # Previously this *returned* the exception object, which made
            # callers crash later with an unrelated AttributeError.
            raise ValueError("heck")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, node_rep0 = self.context(
            result.obj_fmap, result.rm_obj_dists.detach(), im_inds,
            result.rm_obj_labels if self.training or self.mode == 'predcls'
            else None, boxes.data, result.boxes_all)

        # The same node representation serves as both subject and object view.
        edge_rep = node_rep0.repeat(1, 2)
        edge_rep = edge_rep.view(edge_rep.size(0), 2, -1)

        global_feature = self.global_embedding(result.fmap.detach())
        result.global_dists = self.global_logist(global_feature)
        # Multi-hot image-level target: which object classes occur per image.
        # NOTE(review): requires result.rm_obj_labels to be set in every mode.
        one_hot_multi = torch.zeros(
            (result.global_dists.shape[0], self.num_classes))

        one_hot_multi[im_inds, result.rm_obj_labels] = 1.0
        result.multi_hot = one_hot_multi.float().cuda()

        subj_global_additive_attention = F.relu(
            self.global_sub_additive(edge_rep[:, 0] + global_feature[im_inds]))
        obj_global_additive_attention = F.relu(
            self.global_obj_additive(edge_rep[:, 1] + global_feature[im_inds]))

        subj_rep = edge_rep[:, 0] + \
            subj_global_additive_attention * global_feature[im_inds]
        obj_rep = edge_rep[:, 1] + \
            obj_global_additive_attention * global_feature[im_inds]

        if self.training:
            self.centroids = self.disc_center.centroids.data

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]

        vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])

        prod_rep = prod_rep * vr

        prod_rep = F.tanh(prod_rep)

        logits, self.direct_memory_feature = self.meta_classify(
            prod_rep, self.centroids)
        result.rel_dists = logits
        result.rel_dists2 = self.direct_memory_feature[-1]
        if self.training:
            result.center_loss = self.disc_center(
                prod_rep, result.rel_labels[:, -1]) * 0.01

        if self.use_bias:
            result.rel_dists = result.rel_dists + 1.0 * self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        # Flat index of each object's predicted class in [num_objs, num_classes].
        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)

    def __getitem__(self, batch):
        """ Hack to do multi-GPU training"""
        batch.scatter()
        if self.num_gpus == 1:
            return self(*batch[0])

        replicas = nn.parallel.replicate(self,
                                         devices=list(range(self.num_gpus)))
        outputs = nn.parallel.parallel_apply(
            replicas, [batch[i] for i in range(self.num_gpus)])

        if self.training:
            return gather_res(outputs, 0, dim=0)
        return outputs
Пример #11
0
class EndCell(nn.Module):
    """
    Final refinement stage that re-scores the top relation candidates from a
    previous model pass.

    ``forward`` keeps the 100 highest-scoring candidate pairs and re-runs
    object context (``LC``). It then builds subject/object representations
    from object features, class embeddings and box geometry, and
    re-classifies the kept pairs.
    """

    def __init__(self,
                 classes,
                 num_rels,
                 mode='sgdet',
                 embed_dim=200,
                 pooling_dim=4096,
                 use_bias=True,
                 require_overlap_det=True):
        """
        :param classes: Object class names; ``num_classes == len(classes)``.
        :param num_rels: Number of predicate classes (output width of rel_compress).
        :param mode: One of MODES.
        :param embed_dim: Width of the class embedding table.
        :param pooling_dim: Width of per-object / per-pair features.
        :param use_bias: Add FrequencyBias logits to relation scores.
        :param require_overlap_det: In sgdet eval, ``get_rel_inds`` keeps only
               overlapping box pairs. Previously ``self.require_overlap`` was
               never set, so ``get_rel_inds`` raised AttributeError in eval mode.
        """
        super(EndCell, self).__init__()
        self.classes = classes
        self.num_rels = num_rels
        assert mode in MODES
        self.embed_dim = embed_dim
        self.pooling_dim = pooling_dim
        self.use_bias = use_bias
        self.mode = mode
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        # Fixed (non-trainable) class embeddings, stored as a plain Variable.
        self.ort_embedding = torch.autograd.Variable(
            get_ort_embeds(self.num_classes, self.embed_dim).cuda())
        self.context = LC(classes=self.classes,
                          mode=self.mode,
                          embed_dim=self.embed_dim,
                          obj_dim=self.pooling_dim)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=7,
                                              stride=16,
                                              dim=512)
        self.pooling_size = 7

        roi_fmap = [
            Flattener(),
            load_vgg(use_dropout=False,
                     use_relu=False,
                     use_linear=pooling_dim == 4096,
                     pretrained=False).classifier,
        ]
        if pooling_dim != 4096:
            roi_fmap.append(nn.Linear(4096, pooling_dim))
        self.roi_fmap = nn.Sequential(*roi_fmap)
        self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        # NOTE: post_lstm is built but forward() only uses post_emb.
        self.post_lstm = nn.Linear(self.pooling_dim + self.embed_dim + 5,
                                   self.pooling_dim * 2)

        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half contribution comes from LSTM, half from embedding.

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.pooling_dim))
        self.post_lstm.bias.data.zero_()

        # Input = [object feature | class embedding | 5 box-geometry scalars].
        self.post_emb = nn.Linear(self.pooling_dim + self.embed_dim + 5,
                                  self.pooling_dim * 2)

        self.rel_compress = nn.Linear(self.pooling_dim,
                                      self.num_rels,
                                      bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()

    @property
    def num_classes(self):
        """Number of object classes."""
        return len(self.classes)

    def visual_rep(self, features, rois, pair_inds):
        """
        Classify the features
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4]
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :param pair_inds inds to use when predicting
        :return: ``roi_fmap`` applied to the union-box features of each pair
        """
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return self.roi_fmap(uboxes)

    def visual_obj(self, features, rois, pair_inds):
        """Raw union-box features for each pair (no ``roi_fmap``)."""
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return uboxes

    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        """
        Relationship candidates as rows of [img_ind, subj_ind, obj_ind].

        Training: the first three columns of ``rel_labels``.
        Eval: every ordered pair of distinct boxes from the same image,
        optionally restricted to overlapping boxes (``self.require_overlap``).
        """
        if self.training:
            rel_inds = rel_labels[:, :3].data.clone()
        else:
            rel_cands = im_inds.data[:, None] == im_inds.data[None]
            # A box is never related to itself.
            rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

            # Require overlap for detection
            if self.require_overlap:
                rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                       box_priors.data) > 0)

            rel_cands = rel_cands.nonzero()
            # Older PyTorch returns a 0-dim tensor when nothing is nonzero.
            if rel_cands.dim() == 0:
                rel_cands = im_inds.data.new(1, 2).fill_(0)

            rel_inds = torch.cat(
                (im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
        return rel_inds

    def obj_feature_map(self, features, rois):
        """
        Gets the ROI features
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: [num_rois, #dim] array
        """
        feature_pool = RoIAlignFunction(self.pooling_size,
                                        self.pooling_size,
                                        spatial_scale=1 / 16)(features, rois)
        return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1))

    def forward(self, last_outputs, obj_dists, rel_inds, im_inds, rois, boxes):
        """
        Re-score the top-100 relation candidates of a previous pass.

        :param last_outputs: result object of the previous pass (reads
               obj_preds, rm_obj_dists, rel_dists, fmap, rm_obj_labels,
               boxes_all)
        :param obj_dists: object distributions fed to the context module
        :param rel_inds: [num_rels, 3] rows of [img_ind, subj_ind, obj_ind],
               aligned with ``last_outputs.rel_dists``
        :param im_inds: image index of every object
        :param rois: [num_objs, 5] rows of [img_ind, x0, y0, x1, y1]
        :param boxes: per-object boxes (assumed [num_objs, 4] as x0, y0, x1, y1
               -- TODO confirm)
        :return: (filtered_rel_inds, rm_obj_dists, obj_preds, rel_dists)
        """
        twod_inds = arange(last_outputs.obj_preds.data
                           ) * self.num_classes + last_outputs.obj_preds.data
        obj_scores = F.softmax(last_outputs.rm_obj_dists,
                               dim=1).view(-1)[twod_inds]

        # Best non-background predicate score for every candidate pair.
        rel_rep, _ = F.softmax(last_outputs.rel_dists, dim=1)[:, 1:].max(1)
        # Columns 1 and 2 are the subject/object indices (column 0 is the image
        # index); this previously used columns 0 and 1 by mistake.
        rel_scores_argmaxed = rel_rep * obj_scores[
            rel_inds[:, 1]] * obj_scores[rel_inds[:, 2]]
        _, rel_scores_idx = torch.sort(rel_scores_argmaxed.view(-1),
                                       dim=0,
                                       descending=True)
        rel_scores_idx = rel_scores_idx[:100]

        filtered_rel_inds = rel_inds[rel_scores_idx.data]

        obj_fmap = self.obj_feature_map(last_outputs.fmap.detach(), rois)

        rm_obj_dists, obj_preds = self.context(
            obj_fmap, obj_dists.detach(), im_inds,
            last_outputs.rm_obj_labels if self.mode == 'predcls' else None,
            boxes.data, last_outputs.boxes_all)

        obj_dtype = obj_fmap.data.type()
        obj_preds_embeds = torch.index_select(self.ort_embedding, 0,
                                              obj_preds).type(obj_dtype)
        # Four box coordinates scaled by IM_SCALE plus the normalised area.
        # NOTE(review): the coordinate order (0, 3, 2, 1) looks unusual;
        # kept as-is because trained weights depend on it.
        transfered_boxes = torch.stack(
            (boxes[:, 0] / IM_SCALE, boxes[:, 3] / IM_SCALE,
             boxes[:, 2] / IM_SCALE, boxes[:, 1] / IM_SCALE,
             ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])) /
             (IM_SCALE**2)), -1).type(obj_dtype)
        obj_features = torch.cat(
            (obj_fmap, obj_preds_embeds, transfered_boxes), -1)
        edge_rep = self.post_emb(obj_features)

        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)

        subj_rep = edge_rep[:, 0][filtered_rel_inds[:, 1]]
        obj_rep = edge_rep[:, 1][filtered_rel_inds[:, 2]]

        prod_rep = subj_rep * obj_rep

        vr = self.visual_rep(last_outputs.fmap.detach(), rois,
                             filtered_rel_inds[:, 1:])

        prod_rep = prod_rep * vr

        rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            rel_dists = rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    obj_preds[filtered_rel_inds[:, 1]],
                    obj_preds[filtered_rel_inds[:, 2]],
                ), 1))

        return filtered_rel_inds, rm_obj_dists, obj_preds, rel_dists
Пример #12
0
    def __init__(self, classes, rel_classes, mode='sgdet', use_vision=True,
                 embed_dim=200, hidden_dim=256, obj_dim=2048, pooling_dim=2048,
                 pooling_size=7, dropout_rate=0.2, use_bias=True, use_tanh=True, 
                 limit_vision=True, sl_pretrain=False, num_iter=-1, use_resnet=False,
                 reduce_input=False, debug_type=None, post_nms_thresh=0.5):
        """
        Build the object/relation heads of DynamicFilterContext.

        :param classes: Object classes.
        :param rel_classes: Relationship (predicate) classes.
        :param mode: Must be a member of MODES.
        :param use_vision / use_bias / use_tanh / limit_vision: stored flags;
               use_bias also controls creation of ``freq_bias``.
        :param hidden_dim: Base width of the relation head (relation features
               are ``hidden_dim * 3`` wide).
        :param pooling_dim: Width of pooled object features fed to
               ``obj_compress``, ``post_obj`` and ``reduce_rel_input``.
        :param pooling_size: Spatial size of pooled RoI features.
        :param post_nms_thresh: Stored as ``self.nms_thresh``.

        NOTE(review): embed_dim, obj_dim, dropout_rate, sl_pretrain, num_iter,
        use_resnet, reduce_input and debug_type are not referenced in this
        constructor.
        """
        super(DynamicFilterContext, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        assert mode in MODES
        self.mode = mode

        self.use_vision = use_vision 
        self.use_bias = use_bias
        self.use_tanh = use_tanh
        self.use_highway = True
        self.limit_vision = limit_vision

        self.pooling_dim = pooling_dim 
        self.pooling_size = pooling_size
        self.nms_thresh = post_nms_thresh
        
        # Object classifier; self.num_classes is presumably a property defined
        # elsewhere on this class (not visible here) -- TODO confirm.
        self.obj_compress = myNNLinear(self.pooling_dim, self.num_classes, bias=True)

        # self.roi_fmap_obj = load_vgg(pretrained=False).classifier
        # Hand-built replacement for the VGG classifier head: flattened
        # 512 x pooling_size x pooling_size RoI features -> 4096 -> 4096.
        roi_fmap_obj = [myNNLinear(512*self.pooling_size*self.pooling_size, 4096, bias=True),
                        nn.ReLU(inplace=True),
                        nn.Dropout(p=0.5),
                        myNNLinear(4096, 4096, bias=True),
                        nn.ReLU(inplace=True),
                        nn.Dropout(p=0.5)]
        self.roi_fmap_obj = nn.Sequential(*roi_fmap_obj)

        if self.use_bias:
            self.freq_bias = FrequencyBias()

        # 1x1 conv reducing 512-channel object feature maps to reduce_dim channels.
        self.reduce_dim = 256
        self.reduce_obj_fmaps = nn.Conv2d(512, self.reduce_dim, kernel_size=1)

        # MLP mapping a concatenated pair of reduce_dim vectors to one scalar.
        similar_fun = [myNNLinear(self.reduce_dim*2, self.reduce_dim),
                       nn.ReLU(inplace=True),
                       myNNLinear(self.reduce_dim, 1)]
        self.similar_fun = nn.Sequential(*similar_fun)


        # roi_fmap = [Flattener(),
        #     load_vgg(use_dropout=False, use_relu=False, use_linear=self.pooling_dim == 4096, pretrained=False).classifier,]
        # if self.pooling_dim != 4096:
        #     roi_fmap.append(nn.Linear(4096, self.pooling_dim))
        # self.roi_fmap = nn.Sequential(*roi_fmap)
        # Relation feature head over flattened features with reduce_dim*2
        # channels (presumably subject/object maps concatenated -- TODO confirm).
        roi_fmap = [Flattener(),
                    nn.Linear(self.reduce_dim*2*self.pooling_size*self.pooling_size, 4096, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Dropout(p=0.5),
                    nn.Linear(4096, 4096, bias=True)]
        self.roi_fmap = nn.Sequential(*roi_fmap)

        self.hidden_dim = hidden_dim
        # Relation classifier over hidden_dim*3-wide features; self.num_rels is
        # presumably a property defined elsewhere on this class -- TODO confirm.
        self.rel_compress = myNNLinear(self.hidden_dim*3, self.num_rels)
        self.post_obj = myNNLinear(self.pooling_dim, self.hidden_dim*2)
        self.mapping_x = myNNLinear(self.hidden_dim*2, self.hidden_dim*3)
        self.reduce_rel_input = myNNLinear(self.pooling_dim, self.hidden_dim*3)
Пример #13
0
class DynamicFilterContext(nn.Module):
    """
    Object and relationship classification head built on RoI-aligned features.

    Objects are classified from RoIAlign features passed through ``roi_fmap_obj``.
    Predicates are classified from a fused subject/object feature map (the fusion
    variant is selected by ``debug_type``), combined with per-object subject/object
    embeddings and, optionally, a frequency bias.
    """

    def __init__(self, classes, rel_classes, mode='sgdet', use_vision=True,
                 embed_dim=200, hidden_dim=256, obj_dim=2048, pooling_dim=2048,
                 pooling_size=7, dropout_rate=0.2, use_bias=True, use_tanh=True,
                 limit_vision=True, sl_pretrain=False, num_iter=-1, use_resnet=False,
                 reduce_input=False, debug_type=None, post_nms_thresh=0.5):
        """
        :param classes: object class names (``num_classes == len(classes)``)
        :param rel_classes: predicate class names (``num_rels == len(rel_classes)``)
        :param mode: one of MODES; ``base_forward`` only supports 'sgcls'
        :param hidden_dim: size of each subject / object embedding
        :param pooling_dim: output width of the object and relation MLP heads
        :param pooling_size: RoIAlign output resolution (pooling_size x pooling_size)
        :param use_bias: add FrequencyBias logits to the predicate scores
        :param debug_type: subject/object fusion variant, 'test1_0' or 'test1_1';
                           when None the global ``conf.debug_type`` is used
        :param post_nms_thresh: stored as ``self.nms_thresh``
        use_vision, use_tanh and limit_vision are stored; embed_dim, obj_dim,
        dropout_rate, sl_pretrain, num_iter, use_resnet and reduce_input are
        accepted for interface compatibility but not read by this class.
        """
        super(DynamicFilterContext, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        assert mode in MODES
        self.mode = mode

        self.use_vision = use_vision
        self.use_bias = use_bias
        self.use_tanh = use_tanh
        self.use_highway = True
        self.limit_vision = limit_vision
        self.debug_type = debug_type

        self.pooling_dim = pooling_dim
        self.pooling_size = pooling_size
        self.nms_thresh = post_nms_thresh

        self.obj_compress = myNNLinear(self.pooling_dim, self.num_classes, bias=True)

        # Object head over flattened 512-channel RoI features. The final layer
        # emits pooling_dim so it matches obj_compress / post_obj (same as a
        # 4096-wide layer when pooling_dim == 4096).
        roi_fmap_obj = [myNNLinear(512 * self.pooling_size * self.pooling_size, 4096, bias=True),
                        nn.ReLU(inplace=True),
                        nn.Dropout(p=0.5),
                        myNNLinear(4096, self.pooling_dim, bias=True),
                        nn.ReLU(inplace=True),
                        nn.Dropout(p=0.5)]
        self.roi_fmap_obj = nn.Sequential(*roi_fmap_obj)

        if self.use_bias:
            self.freq_bias = FrequencyBias()

        # 1x1 conv shrinking the 512-channel RoI maps before subject/object fusion.
        self.reduce_dim = 256
        self.reduce_obj_fmaps = nn.Conv2d(512, self.reduce_dim, kernel_size=1)

        # Scores a (subject cell, object cell) feature pair with a single logit.
        similar_fun = [myNNLinear(self.reduce_dim * 2, self.reduce_dim),
                       nn.ReLU(inplace=True),
                       myNNLinear(self.reduce_dim, 1)]
        self.similar_fun = nn.Sequential(*similar_fun)

        # Relation head over the fused (2 * reduce_dim)-channel map; the final
        # layer emits pooling_dim so it matches reduce_rel_input.
        roi_fmap = [Flattener(),
                    nn.Linear(self.reduce_dim * 2 * self.pooling_size * self.pooling_size, 4096, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Dropout(p=0.5),
                    nn.Linear(4096, self.pooling_dim, bias=True)]
        self.roi_fmap = nn.Sequential(*roi_fmap)

        self.hidden_dim = hidden_dim
        self.rel_compress = myNNLinear(self.hidden_dim * 3, self.num_rels)
        self.post_obj = myNNLinear(self.pooling_dim, self.hidden_dim * 2)
        self.mapping_x = myNNLinear(self.hidden_dim * 2, self.hidden_dim * 3)
        self.reduce_rel_input = myNNLinear(self.pooling_dim, self.hidden_dim * 3)

    def obj_feature_map(self, features, rois):
        """
        RoIAlign ``features`` over ``rois`` (spatial scale 1/16).

        :param rois: [num_rois, 5] rows of (img_ind, x0, y0, x1, y1)
        :return: pooled maps of spatial size pooling_size x pooling_size
        """
        feature_pool = RoIAlignFunction(self.pooling_size, self.pooling_size, spatial_scale=1 / 16)(
            features, rois)
        return feature_pool

    @property
    def num_classes(self):
        return len(self.classes)

    @property
    def num_rels(self):
        return len(self.rel_classes)

    @property
    def is_sgdet(self):
        return self.mode == 'sgdet'

    @property
    def is_sgcls(self):
        return self.mode == 'sgcls'

    def forward(self, *args, **kwargs):
        """Thin wrapper around :meth:`base_forward`."""
        results = self.base_forward(*args, **kwargs)
        return results

    def base_forward(self, fmaps, obj_logits, im_inds, rel_inds, msg_rel_inds, reward_rel_inds, im_sizes, boxes_priors=None, boxes_deltas=None, boxes_per_cls=None, obj_labels=None):
        """
        Classify objects and relationships.

        :param fmaps: backbone feature maps; assumed 512 channels at stride 16
                      (reduce_obj_fmaps / spatial_scale) -- TODO confirm with caller
        :param obj_logits: ignored; object logits are recomputed from RoI features
        :param im_inds: [num_objs] image index per box
        :param rel_inds: relation rows; columns 1 and 2 are subject / object box indices
        :param boxes_priors: [num_objs, 4] boxes used as RoIs
        :param obj_labels: ground-truth object labels, only read in 'predcls' mode
        msg_rel_inds, reward_rel_inds, im_sizes, boxes_deltas and boxes_per_cls are unused.
        :return: (pred_obj_cls, obj_logits, rel_logits)
        :raises ValueError: if the effective debug_type is not a supported variant
        """
        assert self.mode == 'sgcls', \
            'DynamicFilterContext.base_forward only supports sgcls, got {}'.format(self.mode)
        num_rels = rel_inds.shape[0]

        # Prefer the constructor argument; fall back to the global config.
        debug_type = getattr(self, 'debug_type', None)
        if debug_type is None:
            debug_type = conf.debug_type

        rois = torch.cat((im_inds[:, None].float(), boxes_priors), 1)
        obj_fmaps = self.obj_feature_map(fmaps, rois)
        reduce_obj_fmaps = self.reduce_obj_fmaps(obj_fmaps)

        S_fmaps = reduce_obj_fmaps[rel_inds[:, 1]]
        O_fmaps = reduce_obj_fmaps[rel_inds[:, 2]]

        if debug_type == 'test1_0':
            # Channel-wise concatenation of subject and object maps.
            last_SO_fmaps = torch.cat((S_fmaps, O_fmaps), dim=1)

        elif debug_type == 'test1_1':
            pooling_size_sq = self.pooling_size * self.pooling_size

            # (num_rels, cells, reduce_dim)
            S_fmaps_trans = S_fmaps.view(num_rels, self.reduce_dim, pooling_size_sq).transpose(2, 1)
            O_fmaps_trans = O_fmaps.view(num_rels, self.reduce_dim, pooling_size_sq).transpose(2, 1)

            # Row i * cells + j pairs subject cell i with object cell j.
            S_fmaps_extend = S_fmaps_trans.repeat(1, 1, pooling_size_sq).view(num_rels, pooling_size_sq * pooling_size_sq, self.reduce_dim)
            O_fmaps_extend = O_fmaps_trans.repeat(1, pooling_size_sq, 1)

            SO_fmaps_extend = torch.cat((S_fmaps_extend, O_fmaps_extend), dim=2)
            SO_fmaps_logits = self.similar_fun(SO_fmaps_extend)
            # dim 1 indexes subject cells, dim 2 indexes object cells
            SO_fmaps_logits = SO_fmaps_logits.view(num_rels, pooling_size_sq, pooling_size_sq)

            # Normalise over subject cells: each object cell gets a convex
            # combination of subject cells.
            SO_fmaps_scores = F.softmax(SO_fmaps_logits, dim=1)

            weighted_S_fmaps = torch.matmul(SO_fmaps_scores.transpose(2, 1), S_fmaps_trans)

            last_SO_fmaps = torch.cat((weighted_S_fmaps, O_fmaps_trans), dim=2)
            last_SO_fmaps = last_SO_fmaps.transpose(2, 1).contiguous().view(num_rels, self.reduce_dim * 2, self.pooling_size, self.pooling_size)
        else:
            raise ValueError('Unsupported debug_type: {!r}'.format(debug_type))

        # Object classification (class 0 is excluded from the argmax).
        obj_feats = self.roi_fmap_obj(obj_fmaps.view(rois.size(0), -1))
        obj_logits = self.obj_compress(obj_feats)
        obj_dists = F.softmax(obj_logits, dim=1)
        pred_obj_cls = obj_dists[:, 1:].max(1)[1] + 1

        # Relationship classification.
        rel_input = self.roi_fmap(last_SO_fmaps)
        subobj_rep = self.post_obj(obj_feats)
        sub_rep = subobj_rep[:, :self.hidden_dim][rel_inds[:, 1]]
        obj_rep = subobj_rep[:, self.hidden_dim:][rel_inds[:, 2]]

        last_rel_input = self.reduce_rel_input(rel_input)
        last_obj_input = self.mapping_x(torch.cat((sub_rep, obj_rep), 1))
        triple_rep = F.relu(last_obj_input + last_rel_input) - (last_obj_input - last_rel_input).pow(2)

        rel_logits = self.rel_compress(triple_rep)

        # Frequency bias, following the neural-motifs paper.
        if self.use_bias:
            if self.mode in ['sgcls', 'sgdet']:
                rel_logits = rel_logits + self.freq_bias.index_with_labels(
                    torch.stack((
                        pred_obj_cls[rel_inds[:, 1]],
                        pred_obj_cls[rel_inds[:, 2]],
                        ), 1))
            elif self.mode == 'predcls':
                rel_logits = rel_logits + self.freq_bias.index_with_labels(
                    torch.stack((
                        obj_labels[rel_inds[:, 1]],
                        obj_labels[rel_inds[:, 2]],
                        ), 1))
            else:
                raise NotImplementedError

        return pred_obj_cls, obj_logits, rel_logits
Пример #14
0
    def __init__(self,
                 train_data,
                 mode='sgcls',
                 require_overlap_det=True,
                 use_bias=False,
                 test_bias=False,
                 backbone='vgg16',
                 RELS_PER_IMG=1024,
                 min_size=None,
                 max_size=None,
                 edge_model='motifs'):
        """
        Base class for an SGG model.

        Builds a torchvision detector for the chosen backbone, the RoI feature
        heads used for objects and union boxes, and (optionally) a frequency bias.

        :param train_data: dataset exposing ``ind_to_classes`` and ``ind_to_predicates``
        :param mode: (sgcls, predcls, or sgdet)
        :param require_overlap_det: Whether two objects must intersect (only kept in sgdet)
        :param use_bias: build a FrequencyBias from ``train_data``
        :param test_bias: stored as ``self.test_bias``
        :param backbone: 'resnet50' or 'vgg16'; anything else raises NotImplementedError
        :param RELS_PER_IMG: stored as ``self.RELS_PER_IMG``
        :param min_size: detector min image size; 1333 (resnet50) or IM_SCALE (vgg16) when None
        :param max_size: detector max image size; 1333 (resnet50) or IM_SCALE (vgg16) when None
        :param edge_model: forwarded to UnionBoxesAndFeats
        """
        super(RelModelBase, self).__init__()
        self.classes = train_data.ind_to_classes
        self.rel_classes = train_data.ind_to_predicates
        self.mode = mode
        self.backbone = backbone
        self.RELS_PER_IMG = RELS_PER_IMG
        self.pool_sz = 7
        self.stride = 16

        self.use_bias = use_bias
        self.test_bias = test_bias

        # Overlap filtering of candidate pairs only makes sense with detected boxes.
        self.require_overlap = require_overlap_det and mode == 'sgdet'

        if backbone == 'resnet50':
            self.obj_dim, self.fmap_sz = 1024, 21
            min_size = 1333 if min_size is None else min_size
            max_size = 1333 if max_size is None else max_size

            print('\nLoading COCO pretrained model maskrcnn_resnet50_fpn...\n')
            # See https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
            self.detector = torchvision.models.detection.maskrcnn_resnet50_fpn(
                pretrained=True,
                min_size=min_size,
                max_size=max_size,
                box_detections_per_img=50,
                box_score_thresh=0.2)
            heads = self.detector.roi_heads
            # Swap the COCO classifier for one sized to this dataset, drop masks.
            heads.box_predictor = FastRCNNPredictor(
                heads.box_predictor.cls_score.in_features, len(self.classes))
            heads.mask_predictor = None

            children = list(heads.children())
            box_roi_pool, box_head = children[0], children[1]
            self.roi_fmap_obj = copy.deepcopy(box_head)
            self.roi_fmap = copy.deepcopy(box_head)
            self.roi_pool = copy.deepcopy(box_roi_pool)

        elif backbone == 'vgg16':
            self.obj_dim, self.fmap_sz = 4096, 38
            min_size = IM_SCALE if min_size is None else min_size
            max_size = IM_SCALE if max_size is None else max_size

            vgg = load_vgg(use_dropout=False,
                           use_relu=False,
                           use_linear=True,
                           pretrained=False)
            # The detector (and edge_dim below) read out_channels off the backbone.
            vgg.features.out_channels = 512
            anchors = AnchorGenerator(sizes=((32, 64, 128, 256, 512), ),
                                      aspect_ratios=((0.5, 1.0, 2.0), ))
            pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                        output_size=self.pool_sz,
                                                        sampling_ratio=2)
            # Built in the same order as before so weight init is unchanged.
            head = TwoMLPHead(vgg.features.out_channels * self.pool_sz**2,
                              self.obj_dim)
            predictor = FastRCNNPredictor(self.obj_dim,
                                          len(train_data.ind_to_classes))

            self.detector = FasterRCNN(vgg.features,
                                       min_size=min_size,
                                       max_size=max_size,
                                       rpn_anchor_generator=anchors,
                                       box_head=head,
                                       box_predictor=predictor,
                                       box_roi_pool=pooler,
                                       box_detections_per_img=50,
                                       box_score_thresh=0.2)

            self.roi_fmap = nn.Sequential(nn.Flatten(), vgg.classifier)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier
            first_roi_child = list(self.detector.roi_heads.children())[0]
            self.roi_pool = copy.deepcopy(first_roi_child)

        else:
            raise NotImplementedError(self.backbone)

        # Union-box features use the detector backbone's channel count.
        self.edge_dim = self.detector.backbone.out_channels
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pool_sz,
                                              stride=self.stride,
                                              dim=self.edge_dim,
                                              edge_model=edge_model)
        if use_bias:
            self.freq_bias = FrequencyBias(train_data)
Пример #15
0
class RelModelLinknet(nn.Module):
    """
    LinkNet-style relationship model: object detector + linearized context,
    global context encoding, geometric layout encoding and union-box visual
    features, with an optional frequency bias on the predicate logits.
    """
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=4096,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if were not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUS 2 use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size
        :param pooling_dim: width of the per-object edge representation
        nl_obj, nl_edge, order, pass_in_obj_feats_to_decoder, pass_in_obj_feats_to_edge
        and rec_dropout are accepted but not used by this class.
        """
        super(RelModelLinknet, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.ctx_dim = 1024 if use_resnet else 512
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(self.classes,
                                         self.rel_classes,
                                         mode=self.mode,
                                         embed_dim=self.embed_dim,
                                         hidden_dim=self.hidden_dim,
                                         obj_dim=self.obj_dim,
                                         pooling_dim=self.pooling_dim,
                                         ctx_dim=self.ctx_dim)

        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            # NOTE(review): roi_fmap_obj is only built on the VGG path, yet
            # obj_feature_map always uses it.
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        # Global Context Encoding
        self.GCE = GlobalContextEncoding(num_classes=self.num_classes,
                                         ctx_dim=self.ctx_dim)

        ###################################

        # K2: embeds the 4-d relative geometric layout of each pair.
        self.pos_embed = nn.Sequential(*[
            nn.BatchNorm1d(4, momentum=BATCHNORM_MOMENTUM / 10.0),
            nn.Linear(4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ])

        # fc4: predicate classifier over [pair representation, layout embedding].
        self.rel_compress = nn.Linear(self.pooling_dim + 128,
                                      self.num_rels,
                                      bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)

        if self.use_bias:
            self.freq_bias = FrequencyBias()

    @property
    def num_classes(self):
        return len(self.classes)

    @property
    def num_rels(self):
        return len(self.rel_classes)

    def visual_rep(self, features, rois, pair_inds):
        """
        Union-box visual features for each subject/object pair.

        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4]
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :param pair_inds: [num_pairs, 2] box indices (subject, object)
        :return: ``self.roi_fmap`` applied to the pooled union-box features,
                 one row per pair
        """
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return self.roi_fmap(uboxes)

    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        """
        Relationship candidates as rows of (img_ind, subj_ind, obj_ind).

        In training the first three columns of ``rel_labels`` are used. Otherwise
        every ordered pair of distinct boxes from the same image is a candidate,
        restricted to overlapping pairs when ``self.require_overlap`` is set.
        If no candidate survives, a single (img, 0, 0) row is returned.
        """
        if self.training:
            return rel_labels[:, :3].data.clone()

        rel_cands = im_inds.data[:, None] == im_inds.data[None]
        rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

        # Require overlap for detection
        if self.require_overlap:
            rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                   box_priors.data) > 0)

        rel_cands = rel_cands.nonzero()
        # An all-False mask can yield an empty 2-D tensor rather than a 0-dim
        # one, so test for emptiness via numel().
        if rel_cands.numel() == 0:
            rel_cands = im_inds.data.new(1, 2).fill_(0)

        rel_inds = torch.cat(
            (im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
        return rel_inds

    def obj_feature_map(self, features, rois):
        """
        Gets the ROI features
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4] (features at level p2)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: [num_rois, #dim] array
        """
        feature_pool = RoIAlignFunction(self.pooling_size,
                                        self.pooling_size,
                                        spatial_scale=1 / 16)(features, rois)
        return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1))

    def geo_layout_enc(self, box_priors, rel_inds):
        """
        Geometric Layout Encoding
        :param box_priors: [num_rois, 4] of (xmin, ymin, xmax, ymax)
        :param rel_inds: [num_rels, 3] of (img ind, box0 ind, box1 ind)
        :return: bos: [num_rels, 4] relative layout of object w.r.t. subject:
                 (dx / w_s, dy / h_s, log(w_o / w_s), log(h_o / h_s))
        """
        cxcywh = center_size(box_priors.data)  # convert to (cx, cy, w, h)
        box_s = cxcywh[rel_inds[:, 1]]
        box_o = cxcywh[rel_inds[:, 2]]

        # relative location
        rlt_loc_x = torch.div((box_o[:, 0] - box_s[:, 0]),
                              box_s[:, 2]).view(-1, 1)
        rlt_loc_y = torch.div((box_o[:, 1] - box_s[:, 1]),
                              box_s[:, 3]).view(-1, 1)

        # scale information
        scl_info_w = torch.log(torch.div(box_o[:, 2], box_s[:, 2])).view(-1, 1)
        scl_info_h = torch.log(torch.div(box_o[:, 3], box_s[:, 3])).view(-1, 1)

        bos = torch.cat((rlt_loc_x, rlt_loc_y, scl_info_w, scl_info_h), 1)
        return bos

    def glb_context_enc(self, features, im_inds, gt_classes, image_offset):
        """
        Global Context Encoding
        :param features: [batch_size, ctx_dim, IM_SIZE/4, IM_SIZE/4] fmap features
        :param im_inds: [num_rois] image index
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class), or None
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)
        :return: context_features: [num_rois, ctx_dim] stacked context_feature c according to im_ind
                 gce_obj_dists: [batch_size, num_classes] softmax of predicted multi-label distribution: M
                 gce_obj_labels: [batch_size, num_classes] ground truth multi-labels,
                                 or None when ``gt_classes`` is None
        """
        context_feature, gce_obj_dists = self.GCE(features)
        context_features = context_feature[im_inds]

        if gt_classes is None:
            return context_features, gce_obj_dists, None

        gce_obj_labels = torch.zeros_like(gce_obj_dists)
        gce_obj_labels[gt_classes[:, 0] - image_offset, gt_classes[:, 1]] = 1

        return context_features, gce_obj_dists, gce_obj_labels

    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for relationship
        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :return: If train:
            the detector result augmented with rel_dists, gce_obj_dists, gce_obj_labels, ...

            if test:
            filter_dets(...) output (boxes, object scores/classes, relation scores)
        :raises ValueError: if the detector returns an empty result
        """
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        if result.is_none():
            # Previously the exception object was *returned*, which only failed
            # later and far from the cause.
            raise ValueError("heck: object detector returned an empty result")

        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # c M
        context_features, result.gce_obj_dists, result.gce_obj_labels = self.glb_context_enc(
            result.fmap.detach(), im_inds.data,
            gt_classes.data if gt_classes is not None else None, image_offset)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, edge_rep = self.context(
            result.obj_fmap, result.rm_obj_dists.detach(),
            context_features.detach(), result.rm_obj_labels if self.training
            or self.mode == 'predcls' else None, result.boxes_all)

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)  # E1
        subj_rep = edge_rep[:, 0]  # E1_s
        obj_rep = edge_rep[:, 1]  # E1_o

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]  # G0

        if self.use_vision:
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])  # F
            if self.limit_vision:
                # Only the first 2048 dims are gated by vision (exact value TBD).
                prod_rep = torch.cat(
                    (prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1)
            else:
                prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        bos = self.geo_layout_enc(boxes, rel_inds)  # bo|s
        pos_embed = self.pos_embed(Variable(bos))

        result.rel_dists = self.rel_compress(
            torch.cat((prod_rep, pos_embed), 1))  # G2

        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        # Flat index of each box's predicted class in a [num_boxes, num_classes] grid.
        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)

    def __getitem__(self, batch):
        """ Hack to do multi-GPU training"""
        batch.scatter()
        if self.num_gpus == 1:
            return self(*batch[0])

        replicas = nn.parallel.replicate(self,
                                         devices=list(range(self.num_gpus)))
        outputs = nn.parallel.parallel_apply(
            replicas, [batch[i] for i in range(self.num_gpus)])

        if self.training:
            return gather_res(outputs, 0, dim=0)
        return outputs
Пример #16
0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 model_path='',
                 reachability=False,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 init_center=False,
                 limit_vision=True):
        """
        Relationship model: ObjectDetector + LinearizedContext, union-box visual
        features, meta-embedding classifiers and centroid losses for both
        predicates (num_rels) and objects (num_classes).

        :param classes: object classes
        :param rel_classes: relationship (predicate) classes
        :param mode: one of MODES (sgcls, predcls or sgdet)
        :param num_gpus: stored as ``self.num_gpus``
        :param use_vision / use_tanh / limit_vision: stored flags
        :param require_overlap_det: only honoured in sgdet mode
        :param hidden_dim: input width of ``post_lstm``
        :param pooling_dim: feature width used by the classifiers/centroids
        :param use_resnet: selects the ResNet roi_fmap (no ``roi_fmap_obj`` is built then)
        :param model_path: stored as ``self.model_path``
        :param init_center: stored as ``self.init_center``
        :param use_bias: build a FrequencyBias
        ``reachability`` is not referenced in this constructor.
        NOTE(review): module construction order below determines parameter
        initialisation (RNG) and submodule registration order -- do not reorder.
        """

        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        self.init_center = init_center
        self.pooling_size = 7
        self.model_path = model_path
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim
        # Placeholder; replaced by a zero tensor at the end of __init__.
        self.centroids = None
        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        # NOTE(review): width is hard-coded to 4096 regardless of use_resnet / pooling_dim.
        self.global_embedding = EmbeddingImagenet(4096)
        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)
        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            # VGG classifier head; an extra Linear projects 4096 -> pooling_dim
            # when the two differ.
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        # Projects hidden_dim -> 2 * pooling_dim (presumably subject and object
        # halves -- verify against forward).
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)
        # Centroid loss + meta-embedding classifier over predicates ...
        self.disc_center = DiscCentroidsLoss(self.num_rels, self.pooling_dim)
        self.meta_classify = MetaEmbedding_Classifier(
            feat_dim=self.pooling_dim, num_classes=self.num_rels)
        # ... and over object classes.
        self.disc_center_g = DiscCentroidsLoss(self.num_classes,
                                               self.pooling_dim)
        self.meta_classify_g = MetaEmbedding_Classifier(
            feat_dim=self.pooling_dim, num_classes=self.num_classes)
        # NOTE(review): input width hard-coded to 4096.
        self.global_sub_additive = nn.Linear(4096, 1, bias=True)
        self.global_obj_additive = nn.Linear(4096, 1, bias=True)
        # post_lstm weights ~ N(0, std = 10 * sqrt(1 / hidden_dim)), bias = 0.
        # Author's rationale: a sqrt(1/n)-style init so outputs have roughly
        # unit variance, scaled by 10 because the pre-LSTM features were
        # observed to have stdev ~0.1.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        # Object-class classifier over pooling_dim features (a CosineLinear
        # variant was apparently tried, per the trailing comment).
        self.global_logist = nn.Linear(self.pooling_dim,
                                       self.num_classes,
                                       bias=True)  # CosineLinear(4096,150)#
        self.global_logist.weight = torch.nn.init.xavier_normal(
            self.global_logist.weight, gain=1.0)

        # Per-class embedding of width 2 * pooling_dim, initialised N(0, 1).
        self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2)
        self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        self.rel_compress = nn.Linear(self.pooling_dim,
                                      self.num_rels,
                                      bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()
        # Plain tensor attributes (not Parameters or registered buffers), so they
        # are not saved in state_dict and are not moved by .to()/.cuda().
        # class_num lives on CPU; centroids is created directly on the GPU.
        self.class_num = torch.zeros(len(self.classes))
        self.centroids = torch.zeros(len(self.classes),
                                     self.pooling_dim).cuda()
Пример #17
0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=4096,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        Build the relationship model's submodules.

        The submodules are an object detector, a linearized context module,
        union-box visual features, a global context encoder, a 4-d position
        embedding and the final predicate classifier (fc4).

        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet); must be in MODES
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect (only
            takes effect in 'sgdet' mode)
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size
        :param pooling_dim: output size of the union-box feature head; with
            the VGG backbone a Linear(4096, pooling_dim) is appended when it
            is not 4096
        :param use_resnet: use a ResNet backbone (obj_dim=2048, ctx_dim=1024)
            instead of VGG (obj_dim=4096, ctx_dim=512)
        :param thresh: box score threshold passed to the ObjectDetector
        :param use_proposals: in 'sgdet' mode, run the detector in
            'proposals' mode instead of 'refinerels'
        :param use_bias: whether to build a FrequencyBias module
        :param use_tanh: stored as self.use_tanh
        :param limit_vision: stored as self.limit_vision

        NOTE(review): nl_obj, nl_edge, order, pass_in_obj_feats_to_decoder,
        pass_in_obj_feats_to_edge and rec_dropout are accepted for interface
        compatibility but are not used by this constructor (in particular
        they are not forwarded to LinearizedContext).
        """
        super(RelModelLinknet, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        # Backbone-dependent widths: object RoI feature size and context size.
        self.obj_dim = 2048 if use_resnet else 4096
        self.ctx_dim = 1024 if use_resnet else 512
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        # Overlap filtering of candidate pairs is only enabled in sgdet mode.
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        # Outside sgdet the detector is run in 'gtbox' mode.
        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(self.classes,
                                         self.rel_classes,
                                         mode=self.mode,
                                         embed_dim=self.embed_dim,
                                         hidden_dim=self.hidden_dim,
                                         obj_dim=self.obj_dim,
                                         pooling_dim=self.pooling_dim,
                                         ctx_dim=self.ctx_dim)

        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                # Project the 4096-d VGG output down/up to pooling_dim.
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            # NOTE(review): roi_fmap_obj is only created for the VGG backbone;
            # any code using it will fail with use_resnet=True.
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        # Global Context Encoding
        # (self.num_classes is presumably a property defined on this class,
        # outside the visible chunk -- TODO confirm.)
        self.GCE = GlobalContextEncoding(num_classes=self.num_classes,
                                         ctx_dim=self.ctx_dim)

        ###################################

        # K2: embeds a 4-d input (presumably per-box geometry -- TODO confirm)
        # into 128 dims.
        self.pos_embed = nn.Sequential(*[
            nn.BatchNorm1d(4, momentum=BATCHNORM_MOMENTUM / 10.0),
            nn.Linear(4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ])

        # fc4: predicate classifier. The extra 128 input dims presumably match
        # pos_embed's output width -- confirm against forward().
        self.rel_compress = nn.Linear(self.pooling_dim + 128,
                                      self.num_rels,
                                      bias=True)
        # The init function fills the weight in place and returns it, so this
        # assignment just rebinds the same parameter.
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)

        if self.use_bias:
            self.freq_bias = FrequencyBias()
Пример #18
0
class RelModel(nn.Module):
    """
    RELATIONSHIPS

    Scene-graph relationship model: an ObjectDetector proposes boxes and a
    feature map, LinearizedContext refines the object predictions and yields
    an edge context, and the element-wise product of subject and object
    representations (optionally gated by union-box visual features and
    shifted by a frequency bias) is classified into predicates by
    ``rel_compress``.
    """
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        Args:
            classes: list, list of 151 object class names (including background)
            rel_classes: list, list of 51 predicate names (including background, i.e. no relationship)
            mode: string, 'sgdet', 'predcls' or 'sgcls'; must be in MODES
            num_gpus: integer, number of GPUs to use
            use_vision: boolean, whether to multiply in union-box visual features
            require_overlap_det: boolean, whether two objects must intersect to be
                paired at test time (only effective in 'sgdet' mode)
            embed_dim: integer, number of dimensions for all embeddings
            hidden_dim: integer, hidden size of the LSTM
            pooling_dim: integer, output size of the union-box (VGG fc) feature head
            nl_obj: integer, number of object context layers (2 in the paper)
            nl_edge: integer, number of edge context layers (4 in the paper);
                when 0, an embedding table ``post_emb`` is built, which forward
                uses whenever the context module returns no edge context
            use_resnet: boolean, use a ResNet backbone instead of VGG
            order: string, one of ('size', 'confidence', 'random', 'leftright'), order of RoIs
            thresh: float, threshold for box scores; boxes scoring below it are discarded
            use_proposals: boolean, whether to use proposals (sgdet only)
            pass_in_obj_feats_to_decoder: boolean, whether to pass object features to the decoder RNN
            pass_in_obj_feats_to_edge: boolean, whether to pass object features to the edge context RNN
            rec_dropout: float, dropout rate in the RNNs
            use_bias: boolean, whether to add a FrequencyBias term to the predicate logits
            use_tanh: boolean, whether to apply tanh to the pair representation
            limit_vision: boolean, if True only the first 2048 dims of the pair
                representation are multiplied by the visual features
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        # Overlap filtering of candidate pairs only applies to sgdet.
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            # NOTE(review): roi_fmap_obj is only built for the VGG backbone, so
            # obj_feature_map() fails with use_resnet=True.
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        # Maps the edge context to [subject | object] halves of pooling_dim each.
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)

        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half contribution comes from LSTM, half from embedding.)

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        if nl_edge == 0:
            self.post_emb = nn.Embedding(self.num_classes,
                                         self.pooling_dim * 2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        self.rel_compress = nn.Linear(self.pooling_dim,
                                      self.num_rels,
                                      bias=True)
        # The init function fills the weight in place and returns it.
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()

    @property
    def num_classes(self):
        """Number of object classes (including background)."""
        return len(self.classes)

    @property
    def num_rels(self):
        """Number of predicate classes (including background)."""
        return len(self.rel_classes)

    def visual_rep(self, features, rois, pair_inds):
        """
        Compute visual features of the union box of each (subject, object) pair.

        :param features: [batch_size, dim, H, W] backbone feature map (pooled
            with stride 16 by ``self.union_boxes``)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :param pair_inds: [num_pairs, 2] indices into ``rois`` to use when predicting
        :return: ``self.roi_fmap`` applied to the pooled union-box features,
            one row per pair
        """
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return self.roi_fmap(uboxes)

    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        """
        Get the (img_ind, subj_ind, obj_ind) relationship candidates.

        In training the first three columns of ``rel_labels`` are used. At
        test time every ordered pair of distinct boxes from the same image is
        a candidate, restricted to overlapping boxes when
        ``self.require_overlap`` is set.

        :param rel_labels: relationship labels; only read in training mode
        :param im_inds: [num_boxes] image index of every box
        :param box_priors: [num_boxes, 4] boxes used for the overlap test
        :return: [num_rels, 3] tensor of (img_ind, subj_ind, obj_ind)
        """
        if self.training:
            return rel_labels[:, :3].data.clone()

        # Pair boxes from the same image, excluding each box with itself.
        rel_cands = im_inds.data[:, None] == im_inds.data[None]
        rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

        # Require overlap for detection
        if self.require_overlap:
            rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                   box_priors.data) > 0)

        rel_cands = rel_cands.nonzero()
        # With no candidates, older PyTorch returns a 0-dim tensor and newer
        # versions return shape (0, 2); in both cases fall back to a single
        # dummy (0, 0) pair so downstream code always gets at least one row.
        if rel_cands.dim() == 0 or rel_cands.size(0) == 0:
            rel_cands = im_inds.data.new(1, 2).fill_(0)

        return torch.cat(
            (im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)

    def obj_feature_map(self, features, rois):
        """
        Gets the ROI features
        :param features: [batch_size, dim, H, W] backbone feature map (stride 16)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: [num_rois, #dim] array
        """
        feature_pool = RoIAlignFunction(self.pooling_size,
                                        self.pooling_size,
                                        spatial_scale=1 / 16)(features, rois)
        return self.roi_fmap_obj(feature_pool.view(rois.size(0), -1))

    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """Forward pass for detection and relationship classification.

        Args:
            x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
            im_sizes: A numpy array of (h, w, scale) for each image.
            image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

            Training parameters:
            gt_boxes: [num_gt, 4] GT boxes over the batch.
            gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
            gt_rels: ground-truth relations, forwarded to the detector and to
                rel_assignments
            proposals: forwarded to the detector
            train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                be used to compute the training loss. Each (img_ind, fpn_idx)
            return_fmap: unused; the detector is always asked for its feature map

        Returns:
            If train:
                the detector result object with rm_obj_dists, obj_preds,
                obj_fmap, rel_labels and rel_dists filled in
            If test:
                the output of filter_dets(boxes, obj_scores, obj_preds,
                rel pair indices, softmaxed rel_dists)

        Raises:
            ValueError: if the detector produced no result.
        """
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)
        # Attributes of `result` used below (one sgcls example in brackets):
        #   fmap:          backbone feature map             [6, 512, 37, 37]
        #   im_inds:       image index of every proposal    [44]
        #   od_box_priors: proposals before NMS             [44, 4]
        #   od_obj_dists:  logits after score_fc in RCNN    [44, 151]
        #   od_obj_labels:                                  [44]
        #   rm_box_priors: proposals after NMS              [44, 4]
        #   rm_obj_dists:  od_obj_dists after NMS           [44, 151]
        #   rm_obj_labels:                                  [44]
        #   rel_labels:                                     [316, 4]
        #   obj_fmap (filled in below):                     [44, 4096]
        if result.is_none():
            # Nothing downstream can be computed without detections.
            raise ValueError("heck")

        # image_offset refers to the Blob: self.batch_size_per_gpu * index
        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            # Only in sgdet mode do relation labels need to be assigned here.
            assert self.mode == 'sgdet'

            # shapes:
            # im_inds: (box_num,)
            # boxes: (box_num, 4)
            # rm_obj_labels: (box_num,)
            # gt_boxes: (box_num, 4)
            # gt_classes: (box_num, 2) maybe[im_ind, class_ind]
            # gt_rels: (rel_num, 4)
            # image_offset: integer
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        # rel_labels[:, :3] in training
        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        # RoIAlign object features: (NumOfRoI, 4096)
        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, edge_ctx = self.context(
            result.obj_fmap, result.rm_obj_dists.detach(), im_inds,
            result.rm_obj_labels if self.training or self.mode == 'predcls'
            else None, boxes.data, result.boxes_all)

        if edge_ctx is None:
            edge_rep = self.post_emb(result.obj_preds)
        else:
            edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)

        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]

        if self.use_vision:
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
            if self.limit_vision:
                # exact value TBD
                prod_rep = torch.cat(
                    (prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1)
            else:
                prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            # Useful downstream: rm_obj_dists, rm_obj_labels, rel_labels, rel_dists
            return result

        # Flat index of each box's predicted class in the [num_boxes, num_classes] grid.
        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)
        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)

    def __getitem__(self, batch):
        """ Hack to do multi-GPU training"""
        batch.scatter()
        if self.num_gpus == 1:
            return self(*batch[0])

        replicas = nn.parallel.replicate(self,
                                         devices=list(range(self.num_gpus)))
        outputs = nn.parallel.parallel_apply(
            replicas, [batch[i] for i in range(self.num_gpus)])

        if self.training:
            return gather_res(outputs, 0, dim=0)
        return outputs
Пример #19
0
class RelModel(nn.Module):
    """
    RELATIONSHIPS
    """
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 gnn=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        Build the relationship model's submodules.

        Besides the detector, context and union-box features, this variant
        builds a global image embedding with a multi-label classifier,
        additive global attention, a 5-d edge-coordinate embedding, optional
        graph networks (gnn=True) and a metric-learning head.

        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet); must be in MODES
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect (only
            takes effect in 'sgdet' mode)
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size
        :param gnn: whether to build the node/edge GraphNetwork modules
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        self.gnn = gnn
        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        # Global image embedding and multi-label classifier over it.
        # NOTE(review): 151 and 4096 are hard-coded rather than derived from
        # self.num_classes / pooling_dim.
        self.global_embedding = EmbeddingImagenet(4096)
        self.global_logist = nn.Linear(4096, 151,
                                       bias=True)  # CosineLinear(4096,150)#
        self.global_logist.weight = torch.nn.init.xavier_normal(
            self.global_logist.weight, gain=1.0)

        # self.global_rel_logist = nn.Linear(4096, 50 , bias=True)
        # self.global_rel_logist.weight = torch.nn.init.xavier_normal(self.global_rel_logist.weight, gain=1.0)

        # self.global_logist = CosineLinear(4096,150)
        # Scalar scores used to weight the global feature added to the
        # subject / object representations.
        self.global_sub_additive = nn.Linear(4096, 1, bias=True)
        self.global_obj_additive = nn.Linear(4096, 1, bias=True)

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            # NOTE(review): only built for the VGG backbone.
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)

        # Embeds the 5-d pairwise geometry produced by coordinate_feats()
        # into 256 dims.
        self.edge_coordinate_embedding = nn.Sequential(*[
            nn.BatchNorm1d(5, momentum=BATCHNORM_MOMENTUM / 10.0),
            nn.Linear(5, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ])
        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half contribution comes from LSTM, half from embedding.)

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        if nl_edge == 0:
            self.post_emb = nn.Embedding(self.num_classes,
                                         self.pooling_dim * 2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        # NOTE(review): input width (4096 + 256) and output width (51) are
        # hard-coded; they do not follow pooling_dim or self.num_rels.
        self.rel_compress = nn.Linear(4096 + 256, 51, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)

        self.node_transform = nn.Linear(4096, 256, bias=True)
        self.edge_transform = nn.Linear(4096, 256, bias=True)
        # self.rel_compress = CosineLinear(self.pooling_dim+256, self.num_rels)
        # self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()
        if self.gnn:
            self.graph_network_node = GraphNetwork(4096)
            self.graph_network_edge = GraphNetwork()
            # nn.Module starts in training mode, so during construction this
            # always takes the train() branch; the eval() branch is unreachable
            # here. Later model.train()/model.eval() calls propagate to these
            # submodules anyway.
            if self.training:
                self.graph_network_node.train()
                self.graph_network_edge.train()
            else:
                self.graph_network_node.eval()
                self.graph_network_edge.eval()
        self.edge_sim_network = nn.Linear(4096, 1, bias=True)
        self.metric_net = MetricLearning()

    @property
    def num_classes(self):
        return len(self.classes)

    @property
    def num_rels(self):
        return len(self.rel_classes)

    def visual_rep(self, features, rois, pair_inds):
        """
        Classify the features
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4]
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :param pair_inds inds to use when predicting
        :return: score_pred, a [num_rois, num_classes] array
                 box_pred, a [num_rois, num_classes, 4] array
        """
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return self.roi_fmap(uboxes)

    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        # Get the relationship candidates
        if self.training:
            rel_inds = rel_labels[:, :3].data.clone()
        else:
            rel_cands = im_inds.data[:, None] == im_inds.data[None]
            rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

            # Require overlap for detection
            if self.require_overlap:
                rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                       box_priors.data) > 0)

                # if there are fewer then 100 things then we might as well add some?
                amt_to_add = 100 - rel_cands.long().sum()

            rel_cands = rel_cands.nonzero()
            if rel_cands.dim() == 0:
                rel_cands = im_inds.data.new(1, 2).fill_(0)

            rel_inds = torch.cat(
                (im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
        return rel_inds

    def obj_feature_map(self, features, rois):
        """
        RoI-align each box on the feature map and run the object feature head.

        :param features: [batch_size, dim, H, W] backbone feature map, sampled
            with spatial_scale 1/16
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: [num_rois, #dim] output of ``self.roi_fmap_obj``
        """
        side = self.pooling_size
        roi_align = RoIAlignFunction(side, side, spatial_scale=1 / 16)
        pooled = roi_align(features, rois)
        flattened = pooled.view(rois.size(0), -1)
        return self.roi_fmap_obj(flattened)

    def coordinate_feats(self, boxes, rel_inds):
        """
        Build a 5-d relative-geometry feature for every (subject, object) pair.

        ``center_size`` is assumed to return per-box (cx, cy, w, h) -- TODO
        confirm; the code only relies on columns 2:4 being appended to the
        box coordinates and columns 0-3 being used as below.

        :param boxes: [num_boxes, 4] box coordinates (presumably x1, y1, x2, y2)
        :param rel_inds: [num_rels, 3] tensor; columns 1 and 2 index the
            subject and object boxes
        :return: [num_rels, 5] float tensor, allocated on the same device as
            ``boxes``, with columns
                0: (subj[0] - obj_c[0]) / obj_c[2]
                1: (subj[1] - obj_c[1]) / obj_c[3]
                2: (subj[2] - obj_c[0]) / obj_c[2]
                3: (subj[3] - obj_c[1]) / obj_c[3]
                4: subj_c[2] * subj_c[3] / obj_c[2] / obj_c[3]
            where subj/obj are box rows and *_c are center_size rows.
        """
        center = center_size(boxes)
        # Each row: the 4 box coordinates followed by center_size columns 2:4.
        point = torch.cat((boxes, center[:, 2:]), 1)

        sub_point = point[rel_inds[:, 1]]
        obj_center = center[rel_inds[:, 2]]
        obj_c0 = obj_center[:, 0]
        obj_c1 = obj_center[:, 1]
        obj_c2 = obj_center[:, 2]
        obj_c3 = obj_center[:, 3]

        # Allocate on boxes' device rather than the current CUDA device, so
        # this works on CPU and inside per-device replicas.
        edge_of_coordinate_rep = boxes.new(sub_point.size(0), 5).zero_().float()
        edge_of_coordinate_rep[:, 0] = (sub_point[:, 0] - obj_c0) * 1.0 / obj_c2
        edge_of_coordinate_rep[:, 1] = (sub_point[:, 1] - obj_c1) * 1.0 / obj_c3
        edge_of_coordinate_rep[:, 2] = (sub_point[:, 2] - obj_c0) * 1.0 / obj_c2
        edge_of_coordinate_rep[:, 3] = (sub_point[:, 3] - obj_c1) * 1.0 / obj_c3
        edge_of_coordinate_rep[:, 4] = sub_point[:, 4] * sub_point[:, 5] * 1.0 / \
            obj_c2 / obj_c3
        return edge_of_coordinate_rep

    def forward(self,
                x,
                im_sizes,
                image_offset,
                gt_boxes=None,
                gt_classes=None,
                gt_rels=None,
                proposals=None,
                train_anchor_inds=None,
                return_fmap=False):
        """
        Run the detector and score relationships between the detected objects.

        :param x: images, passed straight to ``self.detector``
        :param im_sizes: per-image sizes, passed to ``self.detector``
        :param image_offset: index of the first image handled by this replica;
            subtracted from the detector's image indices
        :param gt_boxes: ground-truth boxes (used to assign relation labels
            in sgdet training)
        :param gt_classes: ground-truth classes (same use as ``gt_boxes``)
        :param gt_rels: ground-truth relationships (same use as ``gt_boxes``)
        :param proposals: passed to ``self.detector``
        :param train_anchor_inds: passed to ``self.detector``
        :param return_fmap: unused; the detector is always asked for its
            feature map because the relation head needs it
        :return: in training mode, the detector result augmented with
            ``global_dists``, ``obj_fmap``, ``multi_hot`` and ``rel_dists``;
            otherwise the output of ``filter_dets``
        :raises ValueError: if the detector returned an empty result
        """
        result = self.detector(x,
                               im_sizes,
                               image_offset,
                               gt_boxes,
                               gt_classes,
                               gt_rels,
                               proposals,
                               train_anchor_inds,
                               return_fmap=True)

        if result.is_none():
            # The exception used to be *returned*, which only surfaced later as
            # an unrelated AttributeError/TypeError in the caller. Raise it.
            raise ValueError("heck")

        # Image indices local to this replica.
        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            # Relation labels are only expected to be missing in sgdet mode.
            assert self.mode == 'sgdet'
            result.rel_labels = rel_assignments(im_inds.data,
                                                boxes.data,
                                                result.rm_obj_labels.data,
                                                gt_boxes.data,
                                                gt_classes.data,
                                                gt_rels.data,
                                                image_offset,
                                                filter_non_overlap=True,
                                                num_sample_per_gt=1)

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        # RoIs in [img_ind, x0, y0, x1, y1] form.
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        # Detach the backbone feature map so these heads do not backprop into it.
        global_feature = self.global_embedding(result.fmap.detach())
        result.global_dists = self.global_logist(global_feature)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, node_rep0 = self.context(
            result.obj_fmap, result.rm_obj_dists.detach(), im_inds,
            result.rm_obj_labels if self.training or self.mode == 'predcls'
            else None, boxes.data, result.boxes_all)

        # Per-image multi-hot vector of the object labels present.
        # NOTE(review): rm_obj_labels may be None outside training in sgdet
        # mode, and this tensor is built on the CPU while im_inds lives on the
        # GPU -- confirm this path is only exercised where both are valid.
        one_hot_multi = torch.zeros(
            (result.global_dists.shape[0], self.num_classes))

        one_hot_multi[im_inds, result.rm_obj_labels] = 1.0
        result.multi_hot = one_hot_multi.float().cuda()
        # Duplicate each node representation into a subject half and an
        # object half.
        edge_rep = node_rep0.repeat(1, 2)

        edge_rep = edge_rep.view(edge_rep.size(0), 2, -1)
        global_feature_re = global_feature[im_inds]
        subj_global_additive_attention = F.relu(
            self.global_sub_additive(edge_rep[:, 0] + global_feature_re))
        # NOTE(review): relu after sigmoid is a no-op (sigmoid is already
        # positive); kept as-is, but the asymmetry with the subject branch
        # may be unintended.
        obj_global_additive_attention = F.relu(
            torch.sigmoid(
                self.global_obj_additive(edge_rep[:, 1] + global_feature_re)))

        subj_rep = edge_rep[:,
                            0] + subj_global_additive_attention * global_feature_re
        obj_rep = edge_rep[:,
                           1] + obj_global_additive_attention * global_feature_re

        edge_of_coordinate_rep = self.coordinate_feats(boxes.data, rel_inds)

        e_ij_coordinate_rep = self.edge_coordinate_embedding(
            edge_of_coordinate_rep)

        union_rep = self.visual_rep(result.fmap.detach(), rois, rel_inds[:,
                                                                         1:])
        edge_feat_init = union_rep

        # Fuse subject, object and union-box features, then append the
        # pairwise coordinate embedding.
        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[
            rel_inds[:, 2]] * edge_feat_init
        prod_rep = torch.cat((prod_rep, e_ij_coordinate_rep), 1)

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            # Add the frequency prior indexed by (subject, object) predictions.
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    result.obj_preds[rel_inds[:, 1]],
                    result.obj_preds[rel_inds[:, 2]],
                ), 1))

        if self.training:
            return result

        # Flat index of each object's predicted class in a
        # [num_obj, num_classes] layout.
        twod_inds = arange(
            result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists,
                                      dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(
                result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)

        return filter_dets(bboxes, result.obj_scores, result.obj_preds,
                           rel_inds[:, 1:], rel_rep)

    def __getitem__(self, batch):
        """Multi-GPU hack: ``model[batch]`` scatters the batch and runs it.

        With a single GPU the first (only) shard is forwarded directly.
        Otherwise the model is replicated over ``self.num_gpus`` devices and
        each shard runs on its own replica. In training mode the per-device
        results are merged with ``gather_res(outputs, 0, dim=0)``; in eval
        mode the list of per-device outputs is returned unchanged.
        """
        batch.scatter()
        n_gpus = self.num_gpus
        if n_gpus == 1:
            return self(*batch[0])

        device_ids = list(range(n_gpus))
        replicas = nn.parallel.replicate(self, devices=device_ids)
        shards = [batch[idx] for idx in device_ids]
        outputs = nn.parallel.parallel_apply(replicas, shards)

        return gather_res(outputs, 0, dim=0) if self.training else outputs
# Example #20
# 0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        Build the relationship model: an ObjectDetector, a LinearizedContext,
        union-box / RoI feature extractors and the relation classifier.

        Args:
            classes: list, list of 151 object class names (including background)
            rel_classes: list, list of 51 predicate names (including background (no relationship))
            mode: string, 'sgdet', 'predcls' or 'sgcls' (must be in MODES)
            num_gpus: integer, number of GPUs to use
            use_vision: boolean, whether to use vision in the final product
            require_overlap_det: boolean, whether two objects must intersect
                (only applied in 'sgdet' mode)
            embed_dim: integer, number of dimensions for all embeddings
            hidden_dim: integer, hidden size of LSTM; input size of post_lstm
            pooling_dim: integer, output size of the VGG fc branch of roi_fmap;
                also the input size of rel_compress
            nl_obj: integer, number of object context layers, 2 in paper
            nl_edge: integer, number of edge context layers, 4 in paper;
                0 additionally creates self.post_emb
            use_resnet: boolean, use resnet for backbone (otherwise VGG)
            order: string, value must be in ('size', 'confidence', 'random', 'leftright'), order of RoIs
            thresh: float, threshold for scores of boxes
                if score of box smaller than thresh, then it will be abandoned
            use_proposals: boolean, in 'sgdet' mode run the detector in
                'proposals' rather than 'refinerels' mode
            pass_in_obj_feats_to_decoder: boolean, whether to pass object features to decoder RNN
            pass_in_obj_feats_to_edge: boolean, whether to pass object features to edge context RNN
            rec_dropout: float, dropout rate in RNN
            use_bias: boolean, whether to build self.freq_bias (FrequencyBias)
            use_tanh: boolean, stored as self.use_tanh
            limit_vision: boolean, stored as self.limit_vision
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        # Width of the per-object RoI feature: 2048 for ResNet, 4096 for VGG.
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        # Overlap filtering is only requested in sgdet mode.
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        # NOTE(review): self.roi_fmap_obj is only created in the VGG branch;
        # confirm nothing reads it when use_resnet is True.
        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            # Project the 4096-d VGG output down to pooling_dim when they differ.
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)

        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half contribution comes from LSTM, half from embedding.

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        # Without edge-context layers a per-class embedding is created with
        # the same output width as post_lstm.
        if nl_edge == 0:
            self.post_emb = nn.Embedding(self.num_classes,
                                         self.pooling_dim * 2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        self.rel_compress = nn.Linear(self.pooling_dim,
                                      self.num_rels,
                                      bias=True)
        # xavier_normal initialises the weight in place; the assignment just
        # rebinds the same parameter.
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()
# Example #21
# 0
    def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True,
                 embed_dim=200, hidden_dim=256, pooling_dim=2048,
                 nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01,
                 use_proposals=False, pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True, rec_dropout=0.0, use_bias=True, use_tanh=True,
                 limit_vision=True):

        """
        Build the relationship model: an ObjectDetector, a LinearizedContext,
        union-box / RoI feature extractors and the relation classifier.

        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet); must be in MODES
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect (only applied in sgdet mode)
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size; input size of post_lstm
        :param pooling_dim: output size of the VGG branch of roi_fmap; input size of rel_compress
        :param nl_obj: number of object-context layers (passed to LinearizedContext)
        :param nl_edge: number of edge-context layers; 0 additionally creates self.post_emb
        :param use_resnet: use a ResNet backbone instead of VGG
        :param order: RoI ordering passed to LinearizedContext
        :param thresh: box-score threshold passed to ObjectDetector
        :param use_proposals: in sgdet mode, run the detector in 'proposals' instead of 'refinerels' mode
        :param pass_in_obj_feats_to_decoder: passed to LinearizedContext
        :param pass_in_obj_feats_to_edge: passed to LinearizedContext
        :param rec_dropout: dropout rate passed to LinearizedContext
        :param use_bias: whether to build self.freq_bias (FrequencyBias)
        :param use_tanh: stored as self.use_tanh
        :param limit_vision: stored as self.limit_vision
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        # Width of the per-object RoI feature: 2048 for ResNet, 4096 for VGG.
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision=limit_vision
        # Overlap filtering is only requested in sgdet mode.
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )
        self.context = LinearizedContext(self.classes, self.rel_classes, mode=self.mode,
                                         embed_dim=self.embed_dim, hidden_dim=self.hidden_dim,
                                         obj_dim=self.obj_dim,
                                         nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout,
                                         order=order,
                                         pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
                                         pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)
        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16,
                                              dim=1024 if use_resnet else 512)
        # NOTE(review): self.roi_fmap_obj is only created in the VGG branch;
        # confirm nothing reads it when use_resnet is True.
        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier,
            ]
            # Project the 4096-d VGG output down to pooling_dim when they differ.
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier
        ###################################
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)

        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half contribution comes from LSTM, half from embedding.

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()
        # Without edge-context layers a per-class embedding is created with
        # the same output width as post_lstm.
        if nl_edge == 0:
            self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim*2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
        # xavier_normal initialises in place; the assignment rebinds the same parameter.
        self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()
class RelModel(nn.Module):
    """
    RELATIONSHIPS
    """
    def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, use_vision=True, require_overlap_det=True,
                 embed_dim=200, hidden_dim=256, pooling_dim=2048,
                 nl_obj=1, nl_edge=2, use_resnet=False, order='confidence', thresh=0.01,
                 use_proposals=False, pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True, rec_dropout=0.1, use_bias=True, use_tanh=True, use_encoded_box=True, use_rl_tree=True, draw_tree=False,
                 limit_vision=True):

        """
        Build the tree-structured relationship model.

        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet); must be in MODES
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision (union-box features) in the final product
        :param require_overlap_det: Whether two objects must intersect (only applied in sgdet mode)
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size; post_lstm maps it to hidden_dim * 2
        :param pooling_dim: size of the fused pair feature fed to rel_compress
        :param nl_obj: number of object-context layers (passed to LinearizedContext)
        :param nl_edge: number of edge-context layers; 0 additionally creates self.post_emb
        :param use_resnet: use a ResNet backbone instead of VGG
        :param order: RoI ordering passed to LinearizedContext
        :param thresh: box-score threshold passed to ObjectDetector
        :param use_proposals: in sgdet mode, run the detector in 'proposals' instead of 'refinerels' mode
        :param pass_in_obj_feats_to_decoder: passed to LinearizedContext
        :param pass_in_obj_feats_to_edge: passed to LinearizedContext
        :param rec_dropout: dropout rate passed to LinearizedContext
        :param use_bias: whether to build self.freq_bias (FrequencyBias)
        :param use_tanh: stored as self.use_tanh
        :param use_encoded_box: build the encode_spatial_1/2 layers for pairwise box encodings
        :param use_rl_tree: passed to ObjectDetector and LinearizedContext
        :param draw_tree: passed to LinearizedContext
        :param limit_vision: stored as self.limit_vision
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        # Co-occurrence statistics loaded from CO_OCCOUR_PATH, normalised so
        # that all entries sum to 1.
        self.co_occour = np.load(CO_OCCOUR_PATH)
        self.co_occour = self.co_occour / self.co_occour.sum()

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        # Width of the per-object RoI feature: 2048 for ResNet, 4096 for VGG.
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.use_encoded_box = use_encoded_box
        self.use_rl_tree = use_rl_tree
        self.draw_tree = draw_tree
        self.limit_vision=limit_vision
        # Overlap filtering is only requested in sgdet mode.
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        self.rl_train = False

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
            use_rl_tree = self.use_rl_tree
        )

        self.context = LinearizedContext(self.classes, self.rel_classes, mode=self.mode,
                                         embed_dim=self.embed_dim, hidden_dim=self.hidden_dim,
                                         obj_dim=self.obj_dim,
                                         nl_obj=nl_obj, nl_edge=nl_edge, dropout_rate=rec_dropout,
                                         order=order,
                                         pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
                                         pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge,
                                         use_rl_tree=self.use_rl_tree,
                                         draw_tree = self.draw_tree)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16,
                                              dim=1024 if use_resnet else 512)

        # NOTE(review): self.roi_fmap_obj is only created in the VGG branch;
        # obj_feature_map reads it, so the ResNet path looks incomplete.
        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier,
            ]
            # Project the 4096-d VGG output down to pooling_dim when they differ.
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(use_dropout=False, pretrained=False).classifier


        ###################################
        # Edge context -> (subject, object) halves of hidden_dim each; the
        # concatenated pair is then projected to pooling_dim by post_cat.
        self.post_lstm = nn.Linear(self.hidden_dim, self.hidden_dim * 2)
        self.post_cat = nn.Linear(self.hidden_dim * 2, self.pooling_dim)
        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half contribution comes from LSTM, half from embedding.

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()
        self.post_cat.weight = torch.nn.init.xavier_normal(self.post_cat.weight,  gain=1.0)
        self.post_cat.bias.data.zero_()

        if self.use_encoded_box:
            # encode spatial info: 32-d pairwise box encoding -> pooling_dim
            self.encode_spatial_1 = nn.Linear(32, 512)
            self.encode_spatial_2 = nn.Linear(512, self.pooling_dim)

            self.encode_spatial_1.weight.data.normal_(0, 1.0)
            self.encode_spatial_1.bias.data.zero_()
            self.encode_spatial_2.weight.data.normal_(0, 0.1)
            self.encode_spatial_2.bias.data.zero_()

        if nl_edge == 0:
            # forward() views the edge representation as (N, 2, hidden_dim),
            # exactly like the output of post_lstm (hidden_dim * 2). The
            # embedding used in its place must therefore produce
            # hidden_dim * 2 features; pooling_dim * 2 broke that view whenever
            # pooling_dim != hidden_dim (e.g. with the default arguments).
            self.post_emb = nn.Embedding(self.num_classes, self.hidden_dim * 2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
        # xavier_normal initialises in place; the assignment rebinds the same parameter.
        self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()

    @property
    def num_classes(self):
        return len(self.classes)

    @property
    def num_rels(self):
        return len(self.rel_classes)

    def visual_rep(self, features, rois, pair_inds):
        """
        Classify the features
        :param features: [batch_size, dim, IM_SIZE/4, IM_SIZE/4]
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :param pair_inds inds to use when predicting
        :return: score_pred, a [num_rois, num_classes] array
                 box_pred, a [num_rois, num_classes, 4] array
        """
        assert pair_inds.size(1) == 2
        uboxes = self.union_boxes(features, rois, pair_inds)
        return self.roi_fmap(uboxes)

    def get_rel_inds(self, rel_labels, im_inds, box_priors):
        # Get the relationship candidates
        if self.training and not self.use_rl_tree:
            rel_inds = rel_labels[:, :3].data.clone()
        else:
            rel_cands = im_inds.data[:, None] == im_inds.data[None]
            rel_cands.view(-1)[diagonal_inds(rel_cands)] = 0

            # Require overlap for detection
            if self.require_overlap:
                rel_cands = rel_cands & (bbox_overlaps(box_priors.data,
                                                       box_priors.data) > 0)

                # if there are fewer then 100 things then we might as well add some?
                amt_to_add = 100 - rel_cands.long().sum()

            rel_cands = rel_cands.nonzero()
            if rel_cands.dim() == 0:
                rel_cands = im_inds.data.new(1, 2).fill_(0)

            rel_inds = torch.cat((im_inds.data[rel_cands[:, 0]][:, None], rel_cands), 1)
        return rel_inds

    def obj_feature_map(self, features, rois):
        """
        Pool a fixed-size feature for every RoI and run it through ``roi_fmap_obj``.

        :param features: image feature map; RoIAlign samples it with
            ``spatial_scale=1/16`` (presumably a stride-16 map -- confirm)
        :param rois: [num_rois, 5] array of [img_num, x0, y0, x1, y1].
        :return: ``self.roi_fmap_obj`` applied to the flattened pooled
            features, one row per RoI
        """
        side = self.pooling_size
        roi_align = RoIAlignFunction(side, side, spatial_scale=1 / 16)
        pooled = roi_align(features, rois)
        flat = pooled.view(rois.size(0), -1)
        return self.roi_fmap_obj(flat)

    def get_rel_label(self, im_inds, gt_rels, rel_inds):
        """
        Look up the ground-truth predicate of every candidate pair.

        A dense ``num_obj * num_obj`` table is filled from ``gt_rels``. Each gt
        row is read as (img_ind, subj, obj, predicate), with subj/obj counted
        from the first object of that image. The per-image offset is the
        cumulative object count of earlier images, so objects are assumed to
        be grouped by image in ``im_inds``. Pairs with no gt relation get
        label 0. Images are visited as 0, 1, 2, ..., and the scan stops at the
        first image index that has no objects.

        :param im_inds: image index of every object
        :param gt_rels: ground-truth relationships, 4 columns as above
        :param rel_inds: candidate triples (img_ind, subj_ind, obj_ind) with
            batch-global object indices
        :return: a CUDA LongTensor Variable with one label per row of rel_inds
        """
        img_of_obj = im_inds.data.cpu().numpy()
        rels = gt_rels.long().data.cpu().numpy()
        pairs = rel_inds.long().cpu().numpy()

        num_obj = int(im_inds.shape[0])
        dense = np.zeros(num_obj * num_obj, dtype=int)

        img_id = 0
        base = 0
        n_in_img = int(np.sum(img_of_obj == img_id))
        while n_in_img > 0:
            img_rels = rels[rels[:, 0] == img_id]
            flat = (img_rels[:, 1] + base) * num_obj + (img_rels[:, 2] + base)
            dense[flat] = img_rels[:, 3]

            base += n_in_img
            img_id += 1
            n_in_img = int(np.sum(img_of_obj == img_id))

        wanted = pairs[:, 1] * num_obj + pairs[:, 2]
        return Variable(torch.from_numpy(dense[wanted]).long().cuda())


    def forward(self, x, im_sizes, image_offset,
                gt_boxes=None, gt_classes=None, gt_rels=None, proposals=None, train_anchor_inds=None,
                return_fmap=False):
        """
        Forward pass for detection and relationship scoring.

        :param x: Images@[batch_size, 3, IM_SIZE, IM_SIZE]
        :param im_sizes: A numpy array of (h, w, scale) for each image.
        :param image_offset: Offset onto what image we're on for MGPU training (if single GPU this is 0)

        Training parameters:
        :param gt_boxes: [num_gt, 4] GT boxes over the batch.
        :param gt_classes: [num_gt, 2] gt boxes where each one is (img_id, class)
        :param gt_rels: ground-truth relationships; used for label assignment in
            sgdet training, for ``result.rel_label_tkh`` when ``use_rl_tree`` is
            set, and passed on to ``filter_dets``
        :param proposals: passed through to the detector
        :param train_anchor_inds: a [num_train, 2] array of indices for the anchors that will
                                  be used to compute the training loss. Each (img_ind, fpn_idx)
        :param return_fmap: unused; the detector is always asked for its feature map
        :return: In training (and not ``self.rl_train``): the detector result
            augmented with context outputs and ``rel_dists``.
            Otherwise the output of ``filter_dets``; when ``self.rl_train`` is
            set, a ``(result, filter_dets(...))`` tuple.
        :raises ValueError: if the detector produced no result
        """
        result = self.detector(x, im_sizes, image_offset, gt_boxes, gt_classes, gt_rels, proposals,
                               train_anchor_inds, return_fmap=True)

        if result.is_none():
            # The exception used to be *returned*, which only surfaced later as
            # an unrelated AttributeError/TypeError in the caller. Raise it.
            raise ValueError("heck")

        # Image indices local to this replica.
        im_inds = result.im_inds - image_offset
        boxes = result.rm_box_priors

        if self.training and result.rel_labels is None:
            assert self.mode == 'sgdet'
            result.rel_labels, fg_rel_labels = rel_assignments(im_inds.data, boxes.data, result.rm_obj_labels.data,
                                                gt_boxes.data, gt_classes.data, gt_rels.data,
                                                image_offset, filter_non_overlap=True,
                                                num_sample_per_gt=1)

        # No arbitrary forest is built from the graph here (the graph_to_trees
        # path is disabled), so the context receives None.
        arbitrary_forest = None

        rel_inds = self.get_rel_inds(result.rel_labels, im_inds, boxes)

        if self.use_rl_tree:
            result.rel_label_tkh = self.get_rel_label(im_inds, gt_rels, rel_inds)

        # RoIs in [img_ind, x0, y0, x1, y1] form.
        rois = torch.cat((im_inds[:, None].float(), boxes), 1)

        result.obj_fmap = self.obj_feature_map(result.fmap.detach(), rois)

        # whole image feature, used for virtual node: one RoI
        # [img_ind, 0, 0, IM_SCALE, IM_SCALE] per image. Built with zeros on the
        # CPU (the original drew random numbers only to overwrite them) and
        # moved to the GPU once instead of element-by-element.
        batch_size = result.fmap.shape[0]
        image_rois = torch.zeros(batch_size, 5)
        for i in range(batch_size):
            image_rois[i, 0] = i
            image_rois[i, 3] = IM_SCALE
            image_rois[i, 4] = IM_SCALE
        image_rois = Variable(image_rois.cuda())
        image_fmap = self.obj_feature_map(result.fmap.detach(), image_rois)

        # NOTE(review): in sgdet training with rel_labels already provided by
        # the detector, fg_rel_labels is never bound -- confirm that case
        # cannot happen.
        if self.mode != 'sgdet' and self.training:
            fg_rel_labels = result.rel_labels

        # Prevent gradients from flowing back into score_fc from elsewhere
        result.rm_obj_dists, result.obj_preds, edge_ctx, result.gen_tree_loss, result.entropy_loss, result.pair_gate, result.pair_gt = self.context(
            result.obj_fmap,
            result.rm_obj_dists.detach(),
            im_inds, result.rm_obj_labels if self.training or self.mode == 'predcls' else None,
            boxes.data, result.boxes_all, 
            arbitrary_forest,
            image_rois,
            image_fmap,
            self.co_occour,
            fg_rel_labels if self.training else None,
            x)

        if edge_ctx is None:
            edge_rep = self.post_emb(result.obj_preds)
        else:
            edge_rep = self.post_lstm(edge_ctx)

        # Split into subject and object representations
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.hidden_dim)

        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]

        # Concatenate subject and object features per pair, project to pooling_dim.
        prod_rep =  torch.cat((subj_rep[rel_inds[:, 1]], obj_rep[rel_inds[:, 2]]), 1)
        prod_rep = self.post_cat(prod_rep)

        if self.use_encoded_box:
            # encode spatial info
            assert(boxes.shape[1] == 4)
            # encoded_boxes: [box_num, (x1,y1,x2,y2,cx,cy,w,h)]
            encoded_boxes = tree_utils.get_box_info(boxes)
            # encoded_boxes_pair: [batch_size, (box1, box2, unionbox, intersectionbox)]
            encoded_boxes_pair = tree_utils.get_box_pair_info(encoded_boxes[rel_inds[:, 1]], encoded_boxes[rel_inds[:, 2]])
            # encoded_spatial_rep
            spatial_rep = F.relu(self.encode_spatial_2(F.relu(self.encode_spatial_1(encoded_boxes_pair))))
            # element-wise multiply with prod_rep
            prod_rep = prod_rep * spatial_rep

        if self.use_vision:
            vr = self.visual_rep(result.fmap.detach(), rois, rel_inds[:, 1:])
            if self.limit_vision:
                # Only the first 2048 channels are modulated by vision.
                # exact value TBD
                prod_rep = torch.cat((prod_rep[:,:2048] * vr[:,:2048], prod_rep[:,2048:]), 1)
            else:
                prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        result.rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            # Add the frequency prior indexed by (subject, object) predictions.
            result.rel_dists = result.rel_dists + self.freq_bias.index_with_labels(torch.stack((
                result.obj_preds[rel_inds[:, 1]],
                result.obj_preds[rel_inds[:, 2]],
            ), 1))

        if self.training and (not self.rl_train):
            return result

        # Flat index of each object's predicted class in a [num_obj, num_classes] layout.
        twod_inds = arange(result.obj_preds.data) * self.num_classes + result.obj_preds.data
        result.obj_scores = F.softmax(result.rm_obj_dists, dim=1).view(-1)[twod_inds]

        # Bbox regression
        if self.mode == 'sgdet':
            bboxes = result.boxes_all.view(-1, 4)[twod_inds].view(result.boxes_all.size(0), 4)
        else:
            # Boxes will get fixed by filter_dets function.
            bboxes = result.rm_box_priors

        rel_rep = F.softmax(result.rel_dists, dim=1)

        if not self.rl_train:
            return filter_dets(bboxes, result.obj_scores,
                           result.obj_preds, rel_inds[:, 1:], rel_rep, gt_boxes, gt_classes, gt_rels)
        else:
            return result, filter_dets(bboxes, result.obj_scores, result.obj_preds, rel_inds[:, 1:], rel_rep, gt_boxes, gt_classes, gt_rels)

    def __getitem__(self, batch):
        """ Hack to do multi-GPU training"""
        batch.scatter()
        if self.num_gpus == 1:
            return self(*batch[0])

        replicas = nn.parallel.replicate(self, devices=list(range(self.num_gpus)))
        outputs = nn.parallel.parallel_apply(replicas, [batch[i] for i in range(self.num_gpus)])

        if self.training:
            return gather_res(outputs, 0, dim=0)
        return outputs
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=False,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=4096,
                 use_resnet=False,
                 thresh=0.01,
                 use_proposals=False,
                 use_bias=True,
                 limit_vision=True,
                 depth_model=None,
                 pretrained_depth=False,
                 active_features=None,
                 frozen_features=None,
                 use_embed=False,
                 **kwargs):
        """
        :param classes: object classes
        :param rel_classes: relationship classes. None if were not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUS 2 use
        :param use_vision: enable the contribution of union of bounding boxes
        :param require_overlap_det: whether two objects must intersect
        :param embed_dim: word2vec embeddings dimension
        :param hidden_dim: dimension of the fusion hidden layer
        :param use_resnet: use resnet as faster-rcnn's backbone
        :param thresh: faster-rcnn related threshold (Threshold for calling it a good box)
        :param use_proposals: whether to use region proposal candidates
        :param use_bias: enable frequency bias
        :param limit_vision: use truncated version of UoBB features
        :param depth_model: provided architecture for depth feature extraction
        :param pretrained_depth: whether the depth feature extractor should be initialized with ImageNet weights
        :param active_features: what set of features should be enabled (e.g. 'vdl' : visual, depth, and location features)
        :param frozen_features: what set of features should be frozen (e.g. 'd' : depth)
        :param use_embed: use word2vec embeddings
        """
        RelModelBase.__init__(self, classes, rel_classes, mode, num_gpus,
                              require_overlap_det, active_features,
                              frozen_features)
        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.use_vision = use_vision
        self.use_bias = use_bias
        self.limit_vision = limit_vision

        # -- Store depth related parameters
        assert depth_model in DEPTH_MODELS
        self.depth_model = depth_model
        self.pretrained_depth = pretrained_depth
        self.depth_pooling_dim = DEPTH_DIMS[self.depth_model]
        self.use_embed = use_embed
        self.detector = nn.Module()
        features_size = 0

        # -- Check whether ResNet is selected as faster-rcnn's backbone
        if use_resnet:
            raise ValueError(
                "The current model does not support ResNet as the Faster-RCNN's backbone."
            )
        """ *** DIFFERENT COMPONENTS OF THE PROPOSED ARCHITECTURE *** 
        This is the part where the different components of the proposed relation detection 
        architecture are defined. In the case of RGB images, we have class probability distribution
        features, visual features, and the location ones. If we are considering depth images as well,
        we augment depth features too. """

        # -- Visual features
        if self.has_visual:
            # -- Define faster R-CNN network and it's related feature extractors
            self.detector = ObjectDetector(
                classes=classes,
                mode=('proposals' if use_proposals else 'refinerels')
                if mode == 'sgdet' else 'gtbox',
                use_resnet=use_resnet,
                thresh=thresh,
                max_per_img=64,
            )
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

            # -- Define union features
            if self.use_vision:
                # -- UoBB pooling module
                self.union_boxes = UnionBoxesAndFeats(
                    pooling_size=self.pooling_size,
                    stride=16,
                    dim=1024 if use_resnet else 512)

                # -- UoBB feature extractor
                roi_fmap = [
                    Flattener(),
                    load_vgg(use_dropout=False,
                             use_relu=False,
                             use_linear=self.hidden_dim == 4096,
                             pretrained=False).classifier,
                ]
                if self.hidden_dim != 4096:
                    roi_fmap.append(nn.Linear(4096, self.hidden_dim))
                self.roi_fmap = nn.Sequential(*roi_fmap)

            # -- Define visual features hidden layer
            self.visual_hlayer = nn.Sequential(*[
                xavier_init(nn.Linear(self.obj_dim * 2, self.FC_SIZE_VISUAL)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.8)
            ])
            self.visual_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_VISUAL

        # -- Location features
        if self.has_loc:
            # -- Define location features hidden layer
            self.location_hlayer = nn.Sequential(*[
                xavier_init(nn.Linear(self.LOC_INPUT_SIZE, self.FC_SIZE_LOC)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1)
            ])
            self.location_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_LOC

        # -- Class features
        if self.has_class:
            if self.use_embed:
                # -- Define class embeddings
                embed_vecs = obj_edge_vectors(self.classes,
                                              wv_dim=self.embed_dim)
                self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim)
                self.obj_embed.weight.data = embed_vecs.clone()

            classme_input_dim = self.embed_dim if self.use_embed else self.num_classes
            # -- Define Class features hidden layer
            self.classme_hlayer = nn.Sequential(*[
                xavier_init(
                    nn.Linear(classme_input_dim * 2, self.FC_SIZE_CLASS)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1)
            ])
            self.classme_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_CLASS

        # -- Depth features
        if self.has_depth:
            # -- Initialize depth backbone
            self.depth_backbone = DepthCNN(depth_model=self.depth_model,
                                           pretrained=self.pretrained_depth)

            # -- Create a relation head which is used to carry on the feature extraction
            # from RoIs of depth features
            self.depth_rel_head = self.depth_backbone.get_classifier()

            # -- Define depth features hidden layer
            self.depth_rel_hlayer = nn.Sequential(*[
                xavier_init(
                    nn.Linear(self.depth_pooling_dim * 2, self.FC_SIZE_DEPTH)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.6),
            ])
            self.depth_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_DEPTH

        # -- Initialize frequency bias if needed
        if self.use_bias:
            self.freq_bias = FrequencyBias()

        # -- *** Fusion layer *** --
        # -- A hidden layer for concatenated features (fusion features)
        self.fusion_hlayer = nn.Sequential(*[
            xavier_init(nn.Linear(features_size, self.hidden_dim)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1)
        ])

        # -- Final FC layer which predicts the relations
        self.rel_out = xavier_init(
            nn.Linear(self.hidden_dim, self.num_rels, bias=True))

        # -- Freeze the user specified features
        if self.frz_visual:
            self.freeze_module(self.detector)
            self.freeze_module(self.roi_fmap_obj)
            self.freeze_module(self.visual_hlayer)
            if self.use_vision:
                self.freeze_module(self.roi_fmap)
                self.freeze_module(self.union_boxes.conv)

        if self.frz_class:
            self.freeze_module(self.classme_hlayer)

        if self.frz_loc:
            self.freeze_module(self.location_hlayer)

        if self.frz_depth:
            self.freeze_module(self.depth_backbone)
            self.freeze_module(self.depth_rel_head)
            self.freeze_module(self.depth_rel_hlayer)
Пример #24
0
class DecoderRNN(torch.nn.Module):
    def __init__(self,
                 classes,
                 rel_classes,
                 embed_dim,
                 obj_dim,
                 inputs_dim,
                 hidden_dim,
                 pooling_dim,
                 recurrent_dropout_probability=0.2,
                 use_highway=True,
                 use_input_projection_bias=True,
                 use_vision=True,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True,
                 sl_pretrain=False,
                 num_iter=-1):
        """
        Initializes the iterative object-label decoder and its relation heads.

        :param classes: list of object class names
        :param rel_classes: list of relation (predicate) class names
        :param embed_dim: word-vector dimension of ``obj_embed2`` (the class
            embedding used by ``get_rel_dist``); also enters the input width
            of ``post_lstm`` (``obj_dim + 2 * embed_dim + 128``)
        :param obj_dim: object-feature width used to size ``post_lstm``
        :param inputs_dim: per-object feature width fed to the LSTM cell at
            every iteration (``input_size`` adds the label-embedding width)
        :param hidden_dim: hidden size of the LSTM cell (``hidden_size``)
        :param pooling_dim: width of each subject/object half produced by
            ``post_lstm`` and input width of ``rel_compress``
        :param recurrent_dropout_probability: probability for the recurrent
            dropout mask built in ``forward`` (0 disables it)
        :param use_highway: add a highway connection to the LSTM cell
        :param use_input_projection_bias: whether ``input_linearity`` has a bias
        :param use_vision: multiply union-box features into the relation
            representation in ``get_rel_dist``
        :param use_bias: build the frequency bias and the count-based
            ``prob_matrix`` (from ``get_counts`` over the first split returned
            by ``VG.splits`` -- presumably the training split)
        :param use_tanh: apply tanh before ``rel_compress`` in ``get_rel_dist``
        :param limit_vision: in ``get_rel_dist``, only modulate the first 2048
            dims with the union-box features
        :param sl_pretrain: when True, ``forward`` commits to the arg-max label
            even in training mode instead of sampling
        :param num_iter: number of decoding iterations run by ``forward``
        """
        super(DecoderRNN, self).__init__()

        self.rel_embedding_dim = 100
        self.classes = classes
        self.rel_classes = rel_classes

        # Label embeddings fed back into the LSTM: row 0 is the 'start' token
        # and row c + 1 is object class c (forward embeds ``label + 1``). The
        # vectors are 100-d regardless of embed_dim, so declare the module with
        # the shape of the weights it actually holds instead of
        # (len(classes), embed_dim), which its attributes would otherwise claim.
        embed_vecs = obj_edge_vectors(['start'] + self.classes, wv_dim=100)
        self.obj_embed = nn.Embedding(embed_vecs.size(0), embed_vecs.size(1))
        self.obj_embed.weight.data = embed_vecs

        embed_rels = obj_edge_vectors(self.rel_classes,
                                      wv_dim=self.rel_embedding_dim)
        self.rel_embed = nn.Embedding(len(self.rel_classes),
                                      self.rel_embedding_dim)
        self.rel_embed.weight.data = embed_rels

        self.embed_dim = embed_dim
        self.obj_dim = obj_dim
        self.hidden_size = hidden_dim
        self.inputs_dim = inputs_dim
        self.pooling_dim = pooling_dim
        self.nms_thresh = 0.3

        self.use_vision = use_vision
        self.use_bias = use_bias
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.sl_pretrain = sl_pretrain
        self.num_iter = num_iter

        self.recurrent_dropout_probability = recurrent_dropout_probability
        self.use_highway = use_highway
        # We do the projections for all the gates all at once, so if we are
        # using highway layers, we need some extra projections, which is
        # why the sizes of the Linear layers change here depending on this flag.
        if use_highway:
            self.input_linearity = torch.nn.Linear(
                self.input_size,
                6 * self.hidden_size,
                bias=use_input_projection_bias)
            self.state_linearity = torch.nn.Linear(self.hidden_size,
                                                   5 * self.hidden_size,
                                                   bias=True)
        else:
            self.input_linearity = torch.nn.Linear(
                self.input_size,
                4 * self.hidden_size,
                bias=use_input_projection_bias)
            self.state_linearity = torch.nn.Linear(self.hidden_size,
                                                   4 * self.hidden_size,
                                                   bias=True)

        self.out = nn.Linear(self.hidden_size, len(self.classes))
        self.reset_parameters()

        # For relation prediction
        embed_vecs2 = obj_edge_vectors(self.classes, wv_dim=embed_dim)
        self.obj_embed2 = nn.Embedding(self.num_classes, embed_dim)
        self.obj_embed2.weight.data = embed_vecs2.clone()

        self.post_lstm = nn.Linear(self.obj_dim + 2 * self.embed_dim + 128,
                                   self.pooling_dim * 2)
        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half contribution comes from LSTM, half from embedding.
        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        # NOTE(review): the std is derived from hidden_size although post_lstm
        # consumes obj_dim + 2 * embed_dim + 128 inputs; may need more thought.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.hidden_size))
        self.post_lstm.bias.data.zero_()

        self.rel_compress = nn.Linear(self.pooling_dim,
                                      self.num_rels,
                                      bias=True)
        # torch.nn.init.xavier_normal is the deprecated alias of xavier_normal_
        # in newer PyTorch releases; it initialises the weight in place.
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()

            # Simple count-based relation model (used by get_simple_rel_dist).
            # Imported lazily so the dataset code is only loaded when needed.
            from dataloaders.visual_genome import VG
            from lib.get_dataset_counts import get_counts, box_filter
            train_data = VG.splits(num_val_im=5000,
                                   filter_non_overlap=True,
                                   filter_duplicate_rels=True,
                                   use_proposals=False)[0]
            fg_matrix, bg_matrix = get_counts(train_data=train_data,
                                              must_overlap=True)
            prob_matrix = fg_matrix.astype(np.float32)
            prob_matrix[:, :, 0] = bg_matrix

            # +1 on the background column keeps every row sum positive before
            # normalising over the relation axis.
            prob_matrix[:, :, 0] += 1
            prob_matrix /= np.sum(prob_matrix, 2)[:, :, None]

            prob_matrix[:, :, 0] = 0  # Zero out BG
            self.prob_matrix = prob_matrix

    @property
    def num_classes(self):
        return len(self.classes)

    @property
    def num_rels(self):
        return len(self.rel_classes)

    @property
    def input_size(self):
        return self.inputs_dim + self.obj_embed.weight.size(1)

    def reset_parameters(self):
        """
        Re-initialise the fused LSTM projections.

        Both projection weights are initialised with ``block_orthogonal`` using
        (hidden_size, input_size) and (hidden_size, hidden_size) blocks. The
        state-projection bias is zeroed except for the forget-gate chunk, which
        is set to 1.0 as recommended by Jozefowicz et al. (2015), "An Empirical
        Exploration of Recurrent Network Architectures".
        """
        hidden = self.hidden_size
        block_orthogonal(self.input_linearity.weight.data,
                         [hidden, self.input_size])
        block_orthogonal(self.state_linearity.weight.data, [hidden, hidden])

        state_bias = self.state_linearity.bias.data
        state_bias.fill_(0.0)
        # The forget gate is the second hidden-sized chunk of the projection.
        forget_slice = slice(hidden, 2 * hidden)
        state_bias[forget_slice].fill_(1.0)

    def lstm_equations(self,
                       timestep_input,
                       previous_state,
                       previous_memory,
                       dropout_mask=None):
        """
        Does the hairy LSTM math
        :param timestep_input:
        :param previous_state:
        :param previous_memory:
        :param dropout_mask:
        :return:
        """
        # Do the projections for all the gates all at once.
        projected_input = self.input_linearity(timestep_input)
        projected_state = self.state_linearity(previous_state)

        # Main LSTM equations using relevant chunks of the big linear
        # projections of the hidden state and inputs.
        input_gate = torch.sigmoid(
            projected_input[:, 0 * self.hidden_size:1 * self.hidden_size] +
            projected_state[:, 0 * self.hidden_size:1 * self.hidden_size])
        forget_gate = torch.sigmoid(
            projected_input[:, 1 * self.hidden_size:2 * self.hidden_size] +
            projected_state[:, 1 * self.hidden_size:2 * self.hidden_size])
        memory_init = torch.tanh(
            projected_input[:, 2 * self.hidden_size:3 * self.hidden_size] +
            projected_state[:, 2 * self.hidden_size:3 * self.hidden_size])
        output_gate = torch.sigmoid(
            projected_input[:, 3 * self.hidden_size:4 * self.hidden_size] +
            projected_state[:, 3 * self.hidden_size:4 * self.hidden_size])
        memory = input_gate * memory_init + forget_gate * previous_memory
        timestep_output = output_gate * torch.tanh(memory)

        if self.use_highway:
            highway_gate = torch.sigmoid(
                projected_input[:, 4 * self.hidden_size:5 * self.hidden_size] +
                projected_state[:, 4 * self.hidden_size:5 * self.hidden_size])
            highway_input_projection = projected_input[:,
                                                       5 * self.hidden_size:6 *
                                                       self.hidden_size]
            timestep_output = highway_gate * timestep_output + (
                1 - highway_gate) * highway_input_projection

        # Only do dropout if the dropout prob is > 0.0 and we are in training mode.
        if dropout_mask is not None and self.training:
            timestep_output = timestep_output * dropout_mask
        return timestep_output, memory

    def get_rel_dist(self, obj_preds, obj_feats, rel_inds, vr=None):
        obj_embed2 = self.obj_embed2(obj_preds)
        edge_ctx = torch.cat((obj_embed2, obj_feats), 1)

        edge_rep = self.post_lstm(edge_ctx)
        edge_rep = edge_rep.view(edge_rep.size(0), 2, self.pooling_dim)

        subj_rep = edge_rep[:, 0]
        obj_rep = edge_rep[:, 1]

        prod_rep = subj_rep[rel_inds[:, 1]] * obj_rep[rel_inds[:, 2]]

        if self.use_vision:
            if self.limit_vision:
                # exact value TBD
                prod_rep = torch.cat(
                    (prod_rep[:, :2048] * vr[:, :2048], prod_rep[:, 2048:]), 1)
            else:
                prod_rep = prod_rep * vr

        if self.use_tanh:
            prod_rep = F.tanh(prod_rep)

        rel_dists = self.rel_compress(prod_rep)

        if self.use_bias:
            rel_dists = rel_dists + self.freq_bias.index_with_labels(
                torch.stack((
                    obj_preds[rel_inds[:, 1]],
                    obj_preds[rel_inds[:, 2]],
                ), 1))

        return rel_dists

    def get_freq_rel_dist(self, obj_preds, rel_inds):
        """
        Baseline: relation model
        """
        rel_dists = self.freq_bias.index_with_labels(
            torch.stack((
                obj_preds[rel_inds[:, 1]],
                obj_preds[rel_inds[:, 2]],
            ), 1))

        return rel_dists

    def get_simple_rel_dist(self, obj_preds, rel_inds):

        obj_preds_np = obj_preds.cpu().numpy()
        rel_inds_np = rel_inds.cpu().numpy()

        rel_dists_list = []
        o1o2 = obj_preds_np[rel_inds_np][:, 1:]
        for o1, o2 in o1o2:
            rel_dists_list.append(self.prob_matrix[o1, o2])

        assert len(rel_dists_list) == len(rel_inds)
        return Variable(
            torch.from_numpy(np.array(rel_dists_list)).cuda(
                obj_preds.get_device())
        )  # there is no gradient for this type of code

    def forward(
            self,  # pylint: disable=arguments-differ
            # inputs: PackedSequence,
        sequence_tensor,
            rel_inds,
            initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            labels=None,
            boxes_for_nms=None,
            vr=None):

        # get the relations for each object
        # numer = torch.arange(0, rel_inds.size(0)).long().cuda(rel_inds.get_device())

        # objs_to_outrels = sequence_tensor.data.new(sequence_tensor.size(0),
        #                                             rel_inds.size(0)).zero_()
        # objs_to_outrels.view(-1)[rel_inds[:, 1] * rel_inds.size(0) + numer] = 1
        # objs_to_outrels = Variable(objs_to_outrels)

        # objs_to_inrels = sequence_tensor.data.new(sequence_tensor.size(0), rel_inds.size(0)).zero_()
        # objs_to_inrels.view(-1)[rel_inds[:, 2] * rel_inds.size(0) + numer] = 1
        # # average the relations for each object, and add "non relation" to the one with on relation communication
        # # test8 / test10 need comments
        # objs_to_inrels = objs_to_inrels / (objs_to_inrels.sum(1) + 1e-8)[:, None]
        # objs_to_inrels = Variable(objs_to_inrels)

        batch_size = sequence_tensor.size(0)

        # We're just doing an LSTM decoder here so ignore states, etc
        if initial_state is None:
            previous_memory = Variable(sequence_tensor.data.new().resize_(
                batch_size, self.hidden_size).fill_(0))
            previous_state = Variable(sequence_tensor.data.new().resize_(
                batch_size, self.hidden_size).fill_(0))
        else:
            assert len(initial_state) == 2
            previous_state = initial_state[0].squeeze(0)
            previous_memory = initial_state[1].squeeze(0)

        # 'start'
        previous_embed = self.obj_embed.weight[0, None].expand(batch_size, 100)

        # previous_comm_info = Variable(sequence_tensor.data.new()
        #                                             .resize_(batch_size, 100).fill_(0))

        if self.recurrent_dropout_probability > 0.0:
            dropout_mask = get_dropout_mask(self.recurrent_dropout_probability,
                                            previous_memory)
        else:
            dropout_mask = None

        # Only accumulating label predictions here, discarding everything else
        out_dists_list = []
        out_commitments_list = []

        end_ind = 0
        for i in range(self.num_iter):

            # timestep_input = torch.cat((sequence_tensor, previous_embed, previous_comm_info), 1)
            timestep_input = torch.cat((sequence_tensor, previous_embed), 1)

            previous_state, previous_memory = self.lstm_equations(
                timestep_input,
                previous_state,
                previous_memory,
                dropout_mask=dropout_mask)

            pred_dist = self.out(previous_state)
            out_dists_list.append(pred_dist)

            # if self.training:
            #     labels_to_embed = labels.clone()
            #     # Whenever labels are 0 set input to be our max prediction
            #     nonzero_pred = pred_dist[:, 1:].max(1)[1] + 1
            #     is_bg = (labels_to_embed.data == 0).nonzero()
            #     if is_bg.dim() > 0:
            #         labels_to_embed[is_bg.squeeze(1)] = nonzero_pred[is_bg.squeeze(1)]
            #     out_commitments_list.append(labels_to_embed)
            #     previous_embed = self.obj_embed(labels_to_embed+1)
            # else:
            #     out_dist_sample = F.softmax(pred_dist, dim=1)
            #     # if boxes_for_nms is not None:
            #     #     out_dist_sample[domains_allowed[i] == 0] = 0.0

            #     # Greedily take the max here amongst non-bgs
            #     best_ind = out_dist_sample[:, 1:].max(1)[1] + 1

            #     # if boxes_for_nms is not None and i < boxes_for_nms.size(0):
            #     #     best_int = int(best_ind.data[0])
            #     #     domains_allowed[i:, best_int] *= (1 - is_overlap[i, i:, best_int])
            #     out_commitments_list.append(best_ind)
            #     previous_embed = self.obj_embed(best_ind+1)
            if self.training and (not self.sl_pretrain):
                import pdb
                pdb.set_trace()
                out_dist_sample = F.softmax(pred_dist, dim=1)
                sample_ind = out_dist_sample[:, 1:].multinomial(
                    1)[:, 0] + 1  # sampling at training stage
                out_commitments_list.append(sample_ind)
                previous_embed = self.obj_embed(sample_ind + 1)
            else:
                out_dist_sample = F.softmax(pred_dist, dim=1)
                # best_ind = out_dist_sample[:, 1:].max(1)[1] + 1
                # debug
                best_ind = out_dist_sample.max(1)[
                    1]  ###########################
                out_commitments_list.append(best_ind)
                previous_embed = self.obj_embed(best_ind + 1)

            # calculate communicate information
            # rel_dists = self.get_rel_dist(best_ind, sequence_tensor, rel_inds, vr)
            # all_comm_info = rel_dists @ self.rel_embed.weight

            # obj_rel_weights = sequence_tensor @ torch.transpose(self.obj_rel_att.weight, 1, 0) @ torch.transpose(all_comm_info, 1, 0)
            # masked_objs_to_inrels = obj_rel_weights * objs_to_inrels
            # objs_to_inrels = masked_objs_to_inrels / (masked_objs_to_inrels.sum(1) + 1e-8)[:, None]

            # previous_comm_info = self.obj_in_lin(objs_to_inrels @ all_comm_info)

        out_dists = out_dists_list[-1]
        out_commitments = out_commitments_list[-1]
        # Do NMS here as a post-processing step
        """
        if boxes_for_nms is not None and not self.training:
            is_overlap = nms_overlaps(boxes_for_nms.data).view(
                boxes_for_nms.size(0), boxes_for_nms.size(0), boxes_for_nms.size(1)
            ).cpu().numpy() >= self.nms_thresh
            # is_overlap[np.arange(boxes_for_nms.size(0)), np.arange(boxes_for_nms.size(0))] = False
            out_dists_sampled = F.softmax(out_dists).data.cpu().numpy()
            out_dists_sampled[:,0] = -1.0 # change 0.0 to 1.0 for the bug when the score for bg is almost 1.
            out_commitments = out_commitments.data.new(len(out_commitments)).fill_(0)
            for i in range(out_commitments.size(0)):
                box_ind, cls_ind = np.unravel_index(out_dists_sampled.argmax(), out_dists_sampled.shape)
                out_commitments[int(box_ind)] = int(cls_ind)
                out_dists_sampled[is_overlap[box_ind,:,cls_ind], cls_ind] = -1.0 #0.0
                out_dists_sampled[box_ind] = -1.0 # This way we won't re-sample
            out_commitments = Variable(out_commitments)
        """
        # rel_dists = self.get_rel_dist(out_commitments, sequence_tensor, rel_inds, vr)
        # simple model
        # import pdb; pdb.set_trace()
        # rel_dists = self.get_freq_rel_dist(out_commitments, rel_inds)

        rel_dists = self.get_simple_rel_dist(out_commitments.data, rel_inds)

        return out_dists_list, out_commitments_list, None, \
                    out_dists, out_commitments, rel_dists