Example #1
0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """Assemble the object detector, context encoder and relation heads.

        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: one of ``MODES`` (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect; only
            honoured when ``mode == 'sgdet'``
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size
        :param pooling_dim: output width of ``roi_fmap``; ``post_lstm`` emits
            twice this and ``rel_compress`` consumes it
        :param nl_obj, nl_edge, order, rec_dropout, pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge: forwarded to ``LinearizedContext``;
            ``nl_edge == 0`` additionally builds ``post_emb``
        :param use_resnet: selects ResNet (obj_dim 2048) instead of VGG (obj_dim 4096)
        :param thresh: forwarded to ``ObjectDetector``
        :param use_proposals: in sgdet mode, run the detector in 'proposals' mode
            instead of 'refinerels'
        :param use_bias: whether to build a ``FrequencyBias`` module
        :param use_tanh, limit_vision: stored on ``self`` for use elsewhere
        """
        super(RelModel, self).__init__()
        assert mode in MODES
        self.classes, self.rel_classes = classes, rel_classes
        self.num_gpus = num_gpus
        self.mode = mode

        # Sizes.
        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        # Behaviour flags.
        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        # Overlap is only ever required in sgdet mode.
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        # NOTE: sub-modules are created in a fixed order on purpose; construction
        # draws from the RNG and determines parameter order.
        if mode != 'sgdet':
            detector_mode = 'gtbox'
        elif use_proposals:
            detector_mode = 'proposals'
        else:
            detector_mode = 'refinerels'
        self.detector = ObjectDetector(classes=classes, mode=detector_mode, use_resnet=use_resnet,
                                       thresh=thresh, max_per_img=64)

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge,
        )

        # Image feats (you'll have to disable this to turn off the features from here).
        backbone_channels = 1024 if use_resnet else 512
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16,
                                              dim=backbone_channels)

        if not use_resnet:
            head_layers = [Flattener()]
            head_layers.append(load_vgg(use_dropout=False,
                                        use_relu=False,
                                        use_linear=pooling_dim == 4096,
                                        pretrained=False).classifier)
            if pooling_dim != 4096:
                # Project the 4096-wide classifier output down to pooling_dim.
                head_layers.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*head_layers)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier
        else:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )

        self.post_lstm = nn.Linear(self.hidden_dim, 2 * self.pooling_dim)
        # Author's rationale: initialise so the outputs have mean 0 / variance 1 (half
        # the contribution from the LSTM, half from the embedding); pre-LSTM activations
        # tend to have stdev ~0.1, hence the factor of 10.
        # NOTE(review): the original comment said sqrt(1/2n) but the code uses
        # sqrt(1/hidden_dim) — confirm which was intended.
        self.post_lstm.weight.data.normal_(0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        if nl_edge == 0:
            self.post_emb = nn.Embedding(self.num_classes, 2 * self.pooling_dim)
            self.post_emb.weight.data.normal_(0, 1.0)

        self.rel_compress = nn.Linear(self.pooling_dim, self.num_rels, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        if use_bias:
            self.freq_bias = FrequencyBias()
Example #2
0
    def __init__(self, classes, rel_classes, graph_path, emb_path, mode='sgdet', num_gpus=1, 
                 require_overlap_det=True, pooling_dim=4096, use_resnet=False, thresh=0.01,
                 use_proposals=False,
                 ggnn_rel_time_step_num=3,
                 ggnn_rel_hidden_dim=512,
                 ggnn_rel_output_dim=512, use_knowledge=True, use_embedding=True, refine_obj_cls=False,
                 rel_counts_path=None, class_volume=1.0, top_k_to_keep=5, normalize_messages=True):

        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param graph_path: forwarded to ``GGNNRelReason``
        :param emb_path: forwarded to ``GGNNRelReason``
        :param mode: one of ``MODES`` (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param require_overlap_det: Whether two objects must intersect; only
            honoured when ``mode == 'sgdet'``
        :param pooling_dim: output width of ``roi_fmap`` (VGG backbone only)
        :param use_resnet: selects ResNet (obj_dim 2048) instead of VGG (obj_dim 4096)
        :param thresh: forwarded to ``ObjectDetector``
        :param use_proposals: in sgdet mode, run the detector in 'proposals' mode
        :param ggnn_rel_time_step_num, ggnn_rel_hidden_dim, ggnn_rel_output_dim,
            use_knowledge, use_embedding, refine_obj_cls, top_k_to_keep,
            normalize_messages: forwarded to ``GGNNRelReason``
        :param rel_counts_path: optional path to a pickled per-predicate count
            vector; when given, class-balanced loss weights are derived from it
            (presumably of length ``num_rels`` — verify against the data file)
        :param class_volume: ``V`` in ``beta = (V - 1) / V`` for the class-balanced
            weighting; ``1.0`` makes every weight equal
        """
        super(KERN, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        self.pooling_size = 7
        self.obj_dim = 2048 if use_resnet else 4096
        self.rel_dim = self.obj_dim
        self.pooling_dim = pooling_dim

        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64
        )

        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False, use_relu=False, use_linear=pooling_dim == 4096, pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        self.ggnn_rel_reason = GGNNRelReason(mode=self.mode,
                                             num_obj_cls=len(self.classes),
                                             num_rel_cls=len(rel_classes),
                                             obj_dim=self.obj_dim,
                                             rel_dim=self.rel_dim,
                                             time_step_num=ggnn_rel_time_step_num,
                                             hidden_dim=ggnn_rel_hidden_dim,
                                             output_dim=ggnn_rel_output_dim,
                                             emb_path=emb_path,
                                             graph_path=graph_path,
                                             refine_obj_cls=refine_obj_cls,
                                             use_knowledge=use_knowledge,
                                             use_embedding=use_embedding,
                                             top_k_to_keep=top_k_to_keep,
                                             normalize_messages=normalize_messages
                                            )

        if rel_counts_path is not None:
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # rel_counts_path at trusted files.
            with open(rel_counts_path, 'rb') as fin:
                rel_counts = pickle.load(fin)
            # Coerce to a float array so a list-valued pickle also works, and clamp
            # counts to >= 1: a predicate with zero count would otherwise give
            # 1 - beta**0 == 0 and an inf/nan weight after normalisation.
            rel_counts = np.maximum(np.asarray(rel_counts, dtype=np.float64), 1.0)
            # Class-balanced weighting: w_c = (1 - beta) / (1 - beta ** n_c).
            beta = (class_volume - 1.0) / class_volume
            rel_class_weights = (1.0 - beta) / (1.0 - beta ** rel_counts)
            # Rescale so the weights sum to num_rels.
            rel_class_weights *= float(self.num_rels) / np.sum(rel_class_weights)
        else:
            rel_class_weights = np.ones((self.num_rels,))
        self.rel_class_weights = Variable(torch.from_numpy(rel_class_weights).float().cuda(), requires_grad=False)
Example #3
0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 require_overlap_det=True,
                 pooling_dim=4096,
                 use_resnet=False,
                 thresh=0.01,
                 use_proposals=False,
                 use_ggnn_obj=False,
                 ggnn_obj_time_step_num=3,
                 ggnn_obj_hidden_dim=512,
                 ggnn_obj_output_dim=512,
                 use_ggnn_rel=False,
                 ggnn_rel_time_step_num=3,
                 ggnn_rel_hidden_dim=512,
                 ggnn_rel_output_dim=512,
                 use_obj_knowledge=True,
                 use_rel_knowledge=True,
                 obj_knowledge='',
                 rel_knowledge=''):
        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: one of ``MODES`` (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param require_overlap_det: Whether two objects must intersect; only
            honoured when ``mode == 'sgdet'``
        :param pooling_dim: output width of ``roi_fmap`` (VGG backbone only)
        :param use_resnet: selects ResNet (obj_dim 2048) instead of VGG (obj_dim 4096)
        :param thresh: forwarded to ``ObjectDetector``
        :param use_proposals: in sgdet mode, run the detector in 'proposals' mode
        :param use_ggnn_obj: build ``ggnn_obj_reason`` (``GGNNObjReason``)
        :param ggnn_obj_time_step_num, ggnn_obj_hidden_dim, ggnn_obj_output_dim:
            forwarded to ``GGNNObjReason``
        :param use_ggnn_rel: build ``ggnn_rel_reason`` (``GGNNRelReason``);
            otherwise a ``VRFC`` head ``vr_fc_cls`` is built instead
        :param ggnn_rel_time_step_num, ggnn_rel_hidden_dim, ggnn_rel_output_dim:
            forwarded to ``GGNNRelReason``
        :param use_obj_knowledge, obj_knowledge: forwarded to ``GGNNObjReason``
        :param use_rel_knowledge, rel_knowledge: forwarded to ``GGNNRelReason``
        """
        super(KERN, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        self.pooling_size = 7
        self.obj_dim = 2048 if use_resnet else 4096
        self.rel_dim = self.obj_dim
        self.pooling_dim = pooling_dim

        self.use_ggnn_obj = use_ggnn_obj
        self.use_ggnn_rel = use_ggnn_rel

        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64)

        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        if self.use_ggnn_obj:
            self.ggnn_obj_reason = GGNNObjReason(
                mode=self.mode,
                num_obj_cls=len(self.classes),
                obj_dim=self.obj_dim,
                time_step_num=ggnn_obj_time_step_num,
                hidden_dim=ggnn_obj_hidden_dim,
                output_dim=ggnn_obj_output_dim,
                use_knowledge=use_obj_knowledge,
                knowledge_matrix=obj_knowledge)

        if self.use_ggnn_rel:
            self.ggnn_rel_reason = GGNNRelReason(
                mode=self.mode,
                num_obj_cls=len(self.classes),
                num_rel_cls=len(rel_classes),
                obj_dim=self.obj_dim,
                rel_dim=self.rel_dim,
                time_step_num=ggnn_rel_time_step_num,
                hidden_dim=ggnn_rel_hidden_dim,
                # Relation GGNN output width comes from the *rel* setting
                # (previously ggnn_obj_output_dim was passed by mistake).
                output_dim=ggnn_rel_output_dim,
                use_knowledge=use_rel_knowledge,
                knowledge_matrix=rel_knowledge)
        else:
            self.vr_fc_cls = VRFC(self.mode, self.rel_dim, len(self.classes),
                                  len(self.rel_classes))
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=4096,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """Build the LinkNet-style relation model.

        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: one of ``MODES`` (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect; only
            honoured when ``mode == 'sgdet'``
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size
        :param pooling_dim: output width of ``roi_fmap``; ``rel_compress``
            consumes ``pooling_dim + 128`` features
        :param use_resnet: selects ResNet (obj_dim 2048, ctx_dim 1024) instead of
            VGG (obj_dim 4096, ctx_dim 512)
        :param thresh: forwarded to ``ObjectDetector``
        :param use_proposals: in sgdet mode, run the detector in 'proposals' mode
        :param use_bias: whether to build a ``FrequencyBias`` module
        :param use_tanh, limit_vision: stored on ``self`` for use elsewhere

        ``nl_obj``, ``nl_edge``, ``order``, ``rec_dropout`` and the
        ``pass_in_obj_feats_*`` flags are accepted for signature compatibility
        but are not read by this constructor.
        """
        super(RelModelLinknet, self).__init__()
        assert mode in MODES
        self.classes, self.rel_classes = classes, rel_classes
        self.num_gpus = num_gpus
        self.mode = mode

        # Sizes derived from the backbone choice.
        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.ctx_dim = 1024 if use_resnet else 512
        self.pooling_dim = pooling_dim

        # Behaviour flags.
        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        # NOTE: sub-module creation order is kept fixed on purpose (RNG draws and
        # parameter order depend on it).
        if mode != 'sgdet':
            detector_mode = 'gtbox'
        elif use_proposals:
            detector_mode = 'proposals'
        else:
            detector_mode = 'refinerels'
        self.detector = ObjectDetector(classes=classes, mode=detector_mode, use_resnet=use_resnet,
                                       thresh=thresh, max_per_img=64)

        self.context = LinearizedContext(self.classes, self.rel_classes,
                                         mode=self.mode,
                                         embed_dim=self.embed_dim,
                                         hidden_dim=self.hidden_dim,
                                         obj_dim=self.obj_dim,
                                         pooling_dim=self.pooling_dim,
                                         ctx_dim=self.ctx_dim)

        # Union-box features share the backbone's channel width (== ctx_dim).
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16,
                                              dim=self.ctx_dim)

        if not use_resnet:
            head_layers = [Flattener()]
            head_layers.append(load_vgg(use_dropout=False,
                                        use_relu=False,
                                        use_linear=pooling_dim == 4096,
                                        pretrained=False).classifier)
            if pooling_dim != 4096:
                head_layers.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*head_layers)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier
        else:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )

        # Global context encoding.
        self.GCE = GlobalContextEncoding(num_classes=self.num_classes, ctx_dim=self.ctx_dim)

        # K2: 4-value box input -> 128-d embedding.
        self.pos_embed = nn.Sequential(
            nn.BatchNorm1d(4, momentum=BATCHNORM_MOMENTUM / 10.0),
            nn.Linear(4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )

        # fc4: relation classifier over pooled features plus the 128-d box embedding.
        self.rel_compress = nn.Linear(self.pooling_dim + 128, self.num_rels, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)

        if use_bias:
            self.freq_bias = FrequencyBias()
Example #5
0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 gnn=True,
                 reachability=False,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: one of ``MODES`` (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect; only
            honoured when ``mode == 'sgdet'``
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size
        :param pooling_dim: output width of ``roi_fmap``; ``disc_center`` and
            ``meta_classify`` use ``pooling_dim + 256`` features
            (NOTE(review): ``rel_compress`` is hard-wired to 4096 + 256 inputs, so
            these agree only when pooling_dim == 4096 — confirm)
        :param nl_obj, nl_edge, order, rec_dropout, pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge: forwarded to ``LinearizedContext``;
            ``nl_edge == 0`` additionally builds ``post_emb``
        :param gnn: whether to build the node/edge ``GraphNetwork`` modules
        :param reachability: stored on ``self.reachability``
        :param use_resnet: selects ResNet (obj_dim 2048) instead of VGG (obj_dim 4096)
        :param thresh: forwarded to ``ObjectDetector``
        :param use_proposals: in sgdet mode, run the detector in 'proposals' mode
        :param use_bias: whether to build a ``FrequencyBias`` module
        :param use_tanh, limit_vision: stored on ``self`` for use elsewhere
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        self.reachability = reachability
        self.gnn = gnn
        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        self.global_embedding = EmbeddingImagenet(4096)
        # Size the global object classifier from the class list rather than a
        # hard-coded 151, consistent with post_emb below.
        self.global_logist = nn.Linear(4096, self.num_classes, bias=True)
        self.global_logist.weight = torch.nn.init.xavier_normal(
            self.global_logist.weight, gain=1.0)

        self.disc_center = DiscCentroidsLoss(self.num_rels,
                                             self.pooling_dim + 256)
        self.meta_classify = MetaEmbedding_Classifier(
            feat_dim=self.pooling_dim + 256, num_classes=self.num_rels)

        self.global_sub_additive = nn.Linear(4096, 1, bias=True)
        self.global_obj_additive = nn.Linear(4096, 1, bias=True)

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)

        # 5-value edge coordinates -> 256-d embedding.
        self.edge_coordinate_embedding = nn.Sequential(*[
            nn.BatchNorm1d(5, momentum=BATCHNORM_MOMENTUM / 10.0),
            nn.Linear(5, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ])
        # Author's rationale: initialise so outputs have mean 0 and variance 1 (half the
        # contribution comes from the LSTM, half from the embedding). In practice the
        # pre-lstm stuff tends to have stdev 0.1, hence the extra factor of 10.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        if nl_edge == 0:
            self.post_emb = nn.Embedding(self.num_classes,
                                         self.pooling_dim * 2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        # Output one logit per predicate class (self.num_rels), matching
        # disc_center / meta_classify, instead of a hard-coded 51.
        self.rel_compress = nn.Linear(4096 + 256, self.num_rels, bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)

        self.node_transform = nn.Linear(4096, 256, bias=True)
        self.edge_transform = nn.Linear(4096, 256, bias=True)
        if self.use_bias:
            self.freq_bias = FrequencyBias()
        if self.gnn:
            self.graph_network_node = GraphNetwork(4096)
            self.graph_network_edge = GraphNetwork()
            # Match the graph networks' mode to this module's current mode;
            # later model.train()/model.eval() calls propagate to them anyway.
            self.graph_network_node.train(self.training)
            self.graph_network_edge.train(self.training)
        self.edge_sim_network = nn.Linear(4096, 1, bias=True)
        self.metric_net = MetricLearning()
Example #6
0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=False,
                 embed_dim=200,
                 hidden_dim=256,
                 obj_dim=2048,
                 pooling_dim=4096,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=True,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True,
                 spatial_dim=128,
                 mp_iter_num=1,
                 trim_graph=True):
        """Build the FCK relation model (relation proposal + message passing + memory RNN).

        Args:
            classes: object class names.
            rel_classes: relationship class names.
            mode: one of ``MODES``; 'sgdet' selects a proposal/refinerels detector,
                anything else runs the detector in 'gtbox' mode.
            embed_dim: width of the (frozen) class word embeddings.
            hidden_dim: width of the word-pair feature and of ``MemoryRNN``.
            obj_dim: width of the object features used by the message-passing layers.
            pooling_dim: not read here; ``self.pooling_dim`` is derived from
                ``use_resnet`` (2048 for ResNet, 4096 for VGG).
            spatial_dim: width of the 4-value box embedding ``spatial_fc``.
            mp_iter_num: integer, number of message passing iteration
            trim_graph: boolean, trim graph in rel pn

        ``nl_obj``, ``nl_edge``, ``order``, ``rec_dropout`` and the
        ``pass_in_obj_feats_*`` flags are accepted but not read by this constructor.
        """
        super(FckModel, self).__init__()
        assert mode in MODES
        self.classes, self.rel_classes = classes, rel_classes
        self.num_gpus = num_gpus
        self.mode = mode

        # Sizes.
        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = obj_dim
        self.pooling_dim = 2048 if use_resnet else 4096
        self.spatial_dim = spatial_dim

        # Behaviour flags.
        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.mp_iter_num = mp_iter_num
        self.trim_graph = trim_graph

        # NOTE: sub-module creation order is kept fixed on purpose (RNG draws and
        # parameter order depend on it).

        # Frozen word vectors, one per object class.
        word_vectors = obj_edge_vectors(self.classes, wv_dim=embed_dim)
        self.classes_word_embedding = nn.Embedding(self.num_classes, embed_dim)
        self.classes_word_embedding.weight.data = word_vectors.clone()
        self.classes_word_embedding.weight.requires_grad = False

        detector_mode = ('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox'
        self.detector = ObjectDetector(classes=classes, mode=detector_mode, use_resnet=use_resnet,
                                       thresh=thresh, max_per_img=64)

        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512,
                                              use_feats=False)
        self.spatial_fc = nn.Sequential(
            nn.Linear(4, spatial_dim),
            nn.BatchNorm1d(spatial_dim, momentum=BATCHNORM_MOMENTUM / 10.),
            nn.ReLU(inplace=True),
        )
        self.word_fc = nn.Sequential(
            nn.Linear(2 * embed_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim, momentum=BATCHNORM_MOMENTUM / 10.),
            nn.ReLU(inplace=True),
        )

        # Union-box feature = object feature + spatial embedding + word-pair feature.
        feats_dim = obj_dim + spatial_dim + hidden_dim
        self.relpn_fc = nn.Linear(feats_dim, 2)
        self.relcnn_fc1 = nn.Sequential(nn.Linear(feats_dim, feats_dim), nn.ReLU(inplace=True))

        # Message-passing layers (v2 model).
        self.box_mp_fc = nn.Sequential(nn.Linear(obj_dim, obj_dim))
        self.sub_rel_mp_fc = nn.Sequential(nn.Linear(feats_dim, obj_dim))
        self.obj_rel_mp_fc = nn.Sequential(nn.Linear(feats_dim, obj_dim))
        self.mp_atten_fc = nn.Sequential(
            nn.Linear(feats_dim + obj_dim, obj_dim),
            nn.ReLU(inplace=True),
            nn.Linear(obj_dim, 1),
        )

        self.cls_fc = nn.Linear(obj_dim, self.num_classes)
        self.relcnn_fc2 = nn.Linear(feats_dim, self.num_rels)

        # Memory module (v3 model).
        self.mem_module = MemoryRNN(classes=classes,
                                    rel_classes=rel_classes,
                                    inputs_dim=feats_dim,
                                    hidden_dim=hidden_dim,
                                    recurrent_dropout_probability=.0)

        if not use_resnet:
            vgg_head = load_vgg(use_dropout=False,
                                use_relu=False,
                                use_linear=self.obj_dim == 4096,
                                pretrained=False).classifier
            self.roi_fmap = nn.Sequential(vgg_head, nn.Linear(self.pooling_dim, self.obj_dim))
        else:
            # Deprecated path.
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
Example #7
0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        :param classes: Object classes
        :param rel_classes: Relationship classes. None if were not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUS 2 use
        :param use_vision: Whether to use vision in the final product
        :param require_overlap_det: Whether two objects must intersect
        :param embed_dim: Dimension for all embeddings
        :param hidden_dim: LSTM hidden size
        :param obj_dim:
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        self.hook_for_grad = False
        self.gradients = []

        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )
        self.ort_embedding = torch.autograd.Variable(
            get_ort_embeds(self.num_classes, 200).cuda())
        embed_vecs = obj_edge_vectors(self.classes, wv_dim=self.embed_dim)
        self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim)
        self.obj_embed.weight.data = embed_vecs.clone()

        # This probably doesn't help it much
        self.pos_embed = nn.Sequential(*[
            nn.BatchNorm1d(4, momentum=BATCHNORM_MOMENTUM / 10.0),
            nn.Linear(4, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        ])

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        self.merge_obj_feats = nn.Sequential(
            nn.Linear(self.obj_dim + self.embed_dim + 128, self.hidden_dim),
            nn.ReLU())

        # self.trans = nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim//4),
        #                             LayerNorm(self.hidden_dim//4), nn.ReLU(),
        #                             nn.Linear(self.hidden_dim//4, self.hidden_dim))

        self.get_phr_feats = nn.Linear(self.pooling_dim, self.hidden_dim)

        self.embeddings4lstm = nn.Embedding(self.num_classes, self.embed_dim)

        self.lstm = nn.LSTM(input_size=self.hidden_dim + self.embed_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=1)

        self.obj_mps1 = Message_Passing4OBJ(self.hidden_dim)
        # self.obj_mps2 = Message_Passing4OBJ(self.hidden_dim)
        self.get_boxes_encode = Boxes_Encode(64)

        if use_resnet:
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        # self.obj_classify_head = nn.Linear(self.pooling_dim, self.num_classes)

        # self.post_emb_s = nn.Linear(self.pooling_dim, self.pooling_dim//2)
        # self.post_emb_s.weight = torch.nn.init.xavier_normal(self.post_emb_s.weight, gain=1.0)
        # self.post_emb_o = nn.Linear(self.pooling_dim, self.pooling_dim//2)
        # self.post_emb_o.weight = torch.nn.init.xavier_normal(self.post_emb_o.weight, gain=1.0)
        # self.merge_obj_high = nn.Linear(self.hidden_dim, self.pooling_dim//2)
        # self.merge_obj_high.weight = torch.nn.init.xavier_normal(self.merge_obj_high.weight, gain=1.0)
        # self.merge_obj_low = nn.Linear(self.pooling_dim + 5 + self.embed_dim, self.pooling_dim//2)
        # self.merge_obj_low.weight = torch.nn.init.xavier_normal(self.merge_obj_low.weight, gain=1.0)
        # self.rel_compress = nn.Linear(self.pooling_dim//2 + 64, self.num_rels, bias=True)
        # self.rel_compress.weight = torch.nn.init.xavier_normal(self.rel_compress.weight, gain=1.0)
        # self.freq_gate = nn.Linear(self.pooling_dim//2 + 64, self.num_rels, bias=True)
        # self.freq_gate.weight = torch.nn.init.xavier_normal(self.freq_gate.weight, gain=1.0)

        self.post_emb_s = nn.Linear(self.pooling_dim, self.pooling_dim)
        self.post_emb_s.weight = torch.nn.init.xavier_normal(
            self.post_emb_s.weight, gain=1.0)
        self.post_emb_o = nn.Linear(self.pooling_dim, self.pooling_dim)
        self.post_emb_o.weight = torch.nn.init.xavier_normal(
            self.post_emb_o.weight, gain=1.0)
        self.merge_obj_high = nn.Linear(self.hidden_dim, self.pooling_dim)
        self.merge_obj_high.weight = torch.nn.init.xavier_normal(
            self.merge_obj_high.weight, gain=1.0)
        self.merge_obj_low = nn.Linear(self.pooling_dim + 5 + self.embed_dim,
                                       self.pooling_dim)
        self.merge_obj_low.weight = torch.nn.init.xavier_normal(
            self.merge_obj_low.weight, gain=1.0)
        self.rel_compress = nn.Linear(self.pooling_dim + 64,
                                      self.num_rels,
                                      bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)
        self.freq_gate = nn.Linear(self.pooling_dim + 64,
                                   self.num_rels,
                                   bias=True)
        self.freq_gate.weight = torch.nn.init.xavier_normal(
            self.freq_gate.weight, gain=1.0)
        # self.ranking_module = nn.Sequential(nn.Linear(self.pooling_dim + 64, self.hidden_dim), nn.ReLU(), nn.Linear(self.hidden_dim, 1))
        if self.use_bias:
            self.freq_bias = FrequencyBias()
Example #8
0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True):
        """
        Build the relationship model: object detector, LinearizedContext,
        union-box feature extractor, RoI feature heads and relation classifier.

        Args:
            classes: list, list of 151 object class names (including background)
            rel_classes: list, list of 51 predicate names (including the
                background / "no relationship" class)
            mode: string, one of MODES: 'sgdet', 'predcls' or 'sgcls'
            num_gpus: integer, number of GPUs to use
            use_vision: boolean, whether to use vision in the final product
            require_overlap_det: boolean, whether two objects must intersect;
                only takes effect in 'sgdet' mode
            embed_dim: integer, number of dimensions for all embeddings
            hidden_dim: integer, hidden size of the LSTM (passed to
                LinearizedContext)
            pooling_dim: integer, output size of the VGG fc head; when it is
                not 4096 an extra Linear(4096, pooling_dim) is appended
            nl_obj: integer, number of object context layers, 2 in paper
            nl_edge: integer, number of edge context layers, 4 in paper;
                when 0, a per-class embedding `post_emb` is also created
            use_resnet: boolean, use a ResNet backbone instead of VGG
            order: string, value must be in ('size', 'confidence', 'random',
                'leftright'), order of RoIs
            thresh: float, threshold for box scores; boxes scoring below
                thresh are discarded
            use_proposals: boolean, in 'sgdet' mode run the detector in
                'proposals' mode instead of 'refinerels'
            pass_in_obj_feats_to_decoder: boolean, whether to pass object
                features to the decoder RNN
            pass_in_obj_feats_to_edge: boolean, whether to pass object
                features to the edge context RNN
            rec_dropout: float, dropout rate in the RNNs
            use_bias: boolean, whether to build a FrequencyBias module
                (`self.freq_bias`)
            use_tanh: boolean, stored as `self.use_tanh`; not used in this
                constructor
            limit_vision: boolean, stored as `self.limit_vision`; not used in
                this constructor
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        # Object feature size handed to LinearizedContext: 2048 with ResNet,
        # 4096 with VGG.
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        # The overlap requirement only applies in 'sgdet' mode.
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        # 'sgdet' runs the detector in 'proposals' or 'refinerels' mode;
        # every other mode uses its 'gtbox' mode.
        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        # Context module configured with the layer counts and RNN options above.
        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            # NOTE(review): unlike the VGG branch, this branch does not define
            # self.roi_fmap_obj — confirm callers never need it with ResNet.
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            # VGG classifier head (pretrained=False); a projection is appended
            # when pooling_dim differs from the head's 4096-d output.
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            # Second, separately parameterized VGG classifier head
            # (presumably for object RoI features — confirm in forward()).
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        # Maps context hidden states to 2 * pooling_dim features.
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)

        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half contribution comes from LSTM, half from embedding.)

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        if nl_edge == 0:
            # Without edge context layers, a per-class embedding of size
            # 2 * pooling_dim is created instead (normal init, std 1.0).
            self.post_emb = nn.Embedding(self.num_classes,
                                         self.pooling_dim * 2)
            self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        # Relation classifier: pooling_dim features -> num_rels outputs.
        # NOTE(review): torch.nn.init.xavier_normal is the pre-0.4 name, later
        # deprecated in favour of xavier_normal_; kept for the targeted PyTorch.
        self.rel_compress = nn.Linear(self.pooling_dim,
                                      self.num_rels,
                                      bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)
        # The frequency-bias module is only built when use_bias is set.
        if self.use_bias:
            self.freq_bias = FrequencyBias()
Example #9
0

        
Example #10
0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 pooling_dim=2048,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=False,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 model_path='',
                 reachability=False,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 init_center=False,
                 limit_vision=True):
        """
        Build the relationship model with centroid-loss / meta-embedding heads.

        Args:
            classes: list, object class names
            rel_classes: list, predicate names
            mode: string, one of MODES ('sgdet', 'sgcls' or 'predcls')
            num_gpus: integer, number of GPUs to use
            use_vision: boolean, stored as `self.use_vision`
            require_overlap_det: boolean, whether two objects must intersect;
                only takes effect in 'sgdet' mode
            embed_dim: integer, embedding dimension for LinearizedContext
            hidden_dim: integer, hidden size passed to LinearizedContext
            pooling_dim: integer, output size of the VGG fc head; when it is
                not 4096 an extra Linear(4096, pooling_dim) is appended
            nl_obj: integer, number of object context layers
            nl_edge: integer, number of edge context layers
            use_resnet: boolean, use a ResNet backbone instead of VGG
            order: string, RoI ordering passed to LinearizedContext
            thresh: float, detector score threshold
            use_proposals: boolean, in 'sgdet' mode run the detector in
                'proposals' mode instead of 'refinerels'
            pass_in_obj_feats_to_decoder: boolean, forwarded to
                LinearizedContext
            model_path: string, stored as `self.model_path`
            reachability: boolean, accepted but neither stored nor used by
                this constructor
            pass_in_obj_feats_to_edge: boolean, forwarded to LinearizedContext
            rec_dropout: float, dropout rate forwarded to LinearizedContext
            use_bias: boolean, whether to build a FrequencyBias module
            use_tanh: boolean, stored as `self.use_tanh`
            init_center: boolean, stored as `self.init_center`
            limit_vision: boolean, stored as `self.limit_vision`

        Note: `self.centroids` is created with `.cuda()`, so constructing this
        model requires a CUDA device.
        """
        super(RelModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode
        self.init_center = init_center
        self.pooling_size = 7
        self.model_path = model_path
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        # Object feature size handed to LinearizedContext: 2048 with ResNet,
        # 4096 with VGG.
        self.obj_dim = 2048 if use_resnet else 4096
        self.pooling_dim = pooling_dim
        # Placeholder; replaced by a CUDA zeros tensor at the end of __init__.
        self.centroids = None
        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        # The overlap requirement only applies in 'sgdet' mode.
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'
        # NOTE(review): 4096 is hard-coded here, independent of pooling_dim.
        self.global_embedding = EmbeddingImagenet(4096)
        # 'sgdet' runs the detector in 'proposals' or 'refinerels' mode;
        # every other mode uses its 'gtbox' mode.
        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels')
            if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = LinearizedContext(
            self.classes,
            self.rel_classes,
            mode=self.mode,
            embed_dim=self.embed_dim,
            hidden_dim=self.hidden_dim,
            obj_dim=self.obj_dim,
            nl_obj=nl_obj,
            nl_edge=nl_edge,
            dropout_rate=rec_dropout,
            order=order,
            pass_in_obj_feats_to_decoder=pass_in_obj_feats_to_decoder,
            pass_in_obj_feats_to_edge=pass_in_obj_feats_to_edge)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512)
        if use_resnet:
            # NOTE(review): unlike the VGG branch, this branch does not define
            # self.roi_fmap_obj — confirm callers never need it with ResNet.
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            # VGG classifier head (pretrained=False); a projection is appended
            # when pooling_dim differs from the head's 4096-d output.
            roi_fmap = [
                Flattener(),
                load_vgg(use_dropout=False,
                         use_relu=False,
                         use_linear=pooling_dim == 4096,
                         pretrained=False).classifier,
            ]
            if pooling_dim != 4096:
                roi_fmap.append(nn.Linear(4096, pooling_dim))
            self.roi_fmap = nn.Sequential(*roi_fmap)
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

        ###################################
        self.post_lstm = nn.Linear(self.hidden_dim, self.pooling_dim * 2)
        # Centroid loss and meta-embedding classifier over the predicate
        # classes (num_rels); the *_g variants do the same over the object
        # classes (num_classes). All operate on pooling_dim features.
        self.disc_center = DiscCentroidsLoss(self.num_rels, self.pooling_dim)
        self.meta_classify = MetaEmbedding_Classifier(
            feat_dim=self.pooling_dim, num_classes=self.num_rels)
        self.disc_center_g = DiscCentroidsLoss(self.num_classes,
                                               self.pooling_dim)
        self.meta_classify_g = MetaEmbedding_Classifier(
            feat_dim=self.pooling_dim, num_classes=self.num_classes)
        # NOTE(review): 4096-d inputs are hard-coded, independent of
        # pooling_dim — confirm the features fed here are always 4096-d.
        self.global_sub_additive = nn.Linear(4096, 1, bias=True)
        self.global_obj_additive = nn.Linear(4096, 1, bias=True)
        # Initialize to sqrt(1/2n) so that the outputs all have mean 0 and variance 1.
        # (Half contribution comes from LSTM, half from embedding.)

        # In practice the pre-lstm stuff tends to have stdev 0.1 so I multiplied this by 10.
        self.post_lstm.weight.data.normal_(
            0, 10.0 * math.sqrt(1.0 / self.hidden_dim))
        self.post_lstm.bias.data.zero_()

        # Object-class head: pooling_dim -> num_classes (a CosineLinear
        # variant appears to have been tried, per the trailing comment).
        self.global_logist = nn.Linear(self.pooling_dim,
                                       self.num_classes,
                                       bias=True)  # CosineLinear(4096,150)#
        self.global_logist.weight = torch.nn.init.xavier_normal(
            self.global_logist.weight, gain=1.0)

        # Per-class embedding of size 2 * pooling_dim (normal init, std 1.0).
        self.post_emb = nn.Embedding(self.num_classes, self.pooling_dim * 2)
        self.post_emb.weight.data.normal_(0, math.sqrt(1.0))

        # Relation classifier: pooling_dim features -> num_rels outputs.
        self.rel_compress = nn.Linear(self.pooling_dim,
                                      self.num_rels,
                                      bias=True)
        self.rel_compress.weight = torch.nn.init.xavier_normal(
            self.rel_compress.weight, gain=1.0)
        if self.use_bias:
            self.freq_bias = FrequencyBias()
        # One zero entry per object class (kept on the CPU).
        self.class_num = torch.zeros(len(self.classes))
        # One pooling_dim zero row per object class, moved to the GPU; this
        # makes model construction require CUDA.
        self.centroids = torch.zeros(len(self.classes),
                                     self.pooling_dim).cuda()
Example #11
0
    def __init__(self, classes, rel_classes, mode='sgdet', num_gpus=1, require_overlap_det=True,
                 embed_dim=200, use_resnet=False, order='confidence', thresh=0.01, use_proposals=False):

        """
        Build the NODIS model: object detector, O_NODE context, union-box
        features, word embeddings, visual/semantic LSTMs, an ODE block and
        the predicate classifier.

        :param classes: Object classes
        :param rel_classes: Relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param require_overlap_det: whether two objects must intersect; only
            takes effect in 'sgdet' mode
        :param embed_dim: dimension of the word embeddings
        :param use_resnet: use a ResNet backbone instead of VGG
        :param order: RoI ordering passed to O_NODE (note: self.order itself is
            always 'random')
        :param thresh: detector score threshold
        :param use_proposals: in 'sgdet' mode, run the detector in 'proposals'
            mode instead of 'refinerels'
        """
        super(NODIS, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim

        # Object feature size passed to O_NODE: 2048 with ResNet, 4096 with VGG.
        self.obj_dim = 2048 if use_resnet else 4096

        # NOTE(review): fixed to 'random' regardless of `order`; only O_NODE
        # receives the caller's `order` — confirm this is intended.
        self.order = 'random'

        # The overlap requirement only applies in 'sgdet' mode.
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        # 'sgdet' runs the detector in 'proposals' or 'refinerels' mode;
        # every other mode uses its 'gtbox' mode.
        self.detector = ObjectDetector(
            classes=classes,
            mode=('proposals' if use_proposals else 'refinerels') if mode == 'sgdet' else 'gtbox',
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )

        self.context = O_NODE(self.classes, self.rel_classes, mode=self.mode, embed_dim=self.embed_dim, obj_dim=self.obj_dim, order=order)

        # Image Feats (You'll have to disable if you want to turn off the features from here)
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size, stride=16,
                                              dim=1024 if use_resnet else 512)

        if use_resnet:
            # NOTE(review): this branch defines neither self.roi_fmap_obj nor
            # self.roi_avg_pool, both of which the VGG branch provides.
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier
            # Non-overlapping 7x7 average pooling. This used to pass stride=0,
            # which only worked because old PyTorch computed
            # `stride or kernel_size` (giving 7). Newer releases keep the 0, and
            # the pooling op rejects a zero stride when the module is called.
            # An explicit stride of 7 preserves the legacy behaviour everywhere.
            self.roi_avg_pool = nn.AvgPool2d(kernel_size=7, stride=7)
        ###################################
        # Two word-embedding tables, each an independent copy initialized from
        # the same obj_edge_vectors vectors.
        embed_vecs = obj_edge_vectors(self.classes, wv_dim=self.embed_dim)
        self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim)
        self.obj_embed.weight.data = embed_vecs.clone()

        self.obj_embed2 = nn.Embedding(self.num_classes, self.embed_dim)
        self.obj_embed2.weight.data = embed_vecs.clone()

        # NOTE(review): LSTM input sizes (1536, 400) are hard-coded; 400 equals
        # 2 * embed_dim only for the default embed_dim=200.
        self.lstm_visual = nn.LSTM(input_size=1536, hidden_size=512)
        self.lstm_semantic = nn.LSTM(input_size=400, hidden_size=512)
        self.odeBlock = odeBlock(odeFunc1(bidirectional=True))

        # NOTE(review): the 51-way output is hard-coded (presumably the
        # predicate count including background), and the trailing ReLU clamps
        # the predicate scores to be non-negative — confirm both are intended.
        self.fc_predicate = nn.Sequential(nn.Linear(1024, 512),
                                          nn.ReLU(inplace=False),
                                          nn.Linear(512, 51),
                                          nn.ReLU(inplace=False))
Example #12
0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=True,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=256,
                 obj_dim=2048,
                 pooling_dim=4096,
                 nl_obj=1,
                 nl_edge=2,
                 use_resnet=True,
                 order='confidence',
                 thresh=0.01,
                 use_proposals=False,
                 pass_in_obj_feats_to_decoder=True,
                 pass_in_obj_feats_to_edge=True,
                 rec_dropout=0.0,
                 use_bias=True,
                 use_tanh=True,
                 limit_vision=True,
                 spatial_dim=128,
                 graph_constrain=True,
                 mp_iter_num=1):
        """
        Args:
            mp_iter_num: integer, number of message passing iteration
        """
        super(FckModel, self).__init__()
        self.classes = classes
        self.rel_classes = rel_classes
        self.num_gpus = num_gpus
        assert mode in MODES
        self.mode = mode

        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = obj_dim
        self.pooling_dim = 2048 if use_resnet else 4096
        self.spatial_dim = spatial_dim

        self.use_bias = use_bias
        self.use_vision = use_vision
        self.use_tanh = use_tanh
        self.limit_vision = limit_vision
        self.require_overlap = require_overlap_det and self.mode == 'sgdet'

        self.graph_cons = graph_constrain
        self.mp_iter_num = mp_iter_num

        classes_word_vec = obj_edge_vectors(self.classes, wv_dim=embed_dim)
        self.classes_word_embedding = nn.Embedding(self.num_classes, embed_dim)
        self.classes_word_embedding.weight.data = classes_word_vec.clone()
        self.classes_word_embedding.weight.requires_grad = False

        # the last one is dirty bit
        self.rel_mem = nn.Embedding(self.num_rels, self.obj_dim + 1)
        self.rel_mem.weight.data[:, -1] = 0

        if mode == 'sgdet':
            if use_proposals:
                obj_detector_mode = 'proposals'
            else:
                obj_detector_mode = 'refinerels'
        else:
            obj_detector_mode = 'gtbox'

        self.detector = ObjectDetector(
            classes=classes,
            mode=obj_detector_mode,
            use_resnet=use_resnet,
            thresh=thresh,
            max_per_img=64,
        )
        self.union_boxes = UnionBoxesAndFeats(pooling_size=self.pooling_size,
                                              stride=16,
                                              dim=1024 if use_resnet else 512,
                                              use_feats=False)
        self.spatial_fc = nn.Sequential(*[
            nn.Linear(4, spatial_dim),
            nn.BatchNorm1d(spatial_dim, momentum=BATCHNORM_MOMENTUM / 10.),
            nn.ReLU(inplace=True)
        ])
        self.word_fc = nn.Sequential(*[
            nn.Linear(2 * embed_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim, momentum=BATCHNORM_MOMENTUM / 10.),
            nn.ReLU(inplace=True)
        ])
        # union box feats
        feats_dim = obj_dim + spatial_dim + hidden_dim
        self.relpn_fc = nn.Linear(feats_dim, 2)
        self.relcnn_fc1 = nn.Sequential(
            *[nn.Linear(feats_dim, feats_dim),
              nn.ReLU(inplace=True)])
        self.box_mp_fc = nn.Sequential(*[
            nn.Linear(obj_dim, obj_dim),
        ])
        self.sub_rel_mp_fc = nn.Sequential(*[nn.Linear(feats_dim, obj_dim)])

        self.obj_rel_mp_fc = nn.Sequential(*[
            nn.Linear(feats_dim, obj_dim),
        ])

        self.mp_atten_fc = nn.Sequential(*[
            nn.Linear(feats_dim + obj_dim, obj_dim),
            nn.ReLU(inplace=True),
            nn.Linear(obj_dim, 1)
        ])

        self.cls_fc = nn.Linear(obj_dim, self.num_classes)
        self.relcnn_fc2 = nn.Linear(
            feats_dim, self.num_rels if self.graph_cons else 2 * self.num_rels)

        if use_resnet:
            #deprecate
            self.roi_fmap = nn.Sequential(
                resnet_l4(relu_end=False),
                nn.AvgPool2d(self.pooling_size),
                Flattener(),
            )
        else:
            roi_fmap = [
                load_vgg(
                    use_dropout=False,
                    use_relu=False,
                    use_linear=self.obj_dim == 4096,
                    pretrained=False,
                ).classifier,
                nn.Linear(self.pooling_dim, self.obj_dim)
            ]
            self.roi_fmap = nn.Sequential(*roi_fmap)