def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 require_overlap_det=True,
                 depth_model=None,
                 pretrained_depth=False,
                 **kwargs):
        """
        Build a relation classifier that relies only on depth features.

        A depth CNN backbone is created and its classifier head is reused to
        extract features from RoIs of the depth feature map. Those features
        go through one hidden FC layer and a final FC layer that scores every
        relation class.

        :param classes: object classes
        :param rel_classes: relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param require_overlap_det: whether two objects must intersect
        :param depth_model: provided architecture for depth feature extraction;
            must be a member of DEPTH_MODELS
        :param pretrained_depth: whether the depth feature extractor should be
            initialized with ImageNet weights; when True the depth backbone is
            frozen
        :param kwargs: accepted for interface compatibility; not used here
        """
        RelModelBase.__init__(self, classes, rel_classes, mode, num_gpus,
                              require_overlap_det)

        # Reject unknown depth architectures before looking them up below.
        assert depth_model in DEPTH_MODELS
        self.depth_model = depth_model
        self.pretrained_depth = pretrained_depth
        self.depth_pooling_dim = DEPTH_DIMS[depth_model]
        self.pooling_size = 7
        # An empty nn.Module stands in for the object detector: none is built
        # by this model.
        self.detector = nn.Module()

        # NOTE: sub-modules are created in a fixed order on purpose; weight
        # initialisation consumes the RNG, and registration order determines
        # parameter order.
        self.depth_backbone = DepthCNN(depth_model=depth_model,
                                       pretrained=pretrained_depth)

        # Relation head reusing the backbone's classifier to extract features
        # from RoIs of the depth feature map.
        self.depth_rel_head = self.depth_backbone.get_classifier()

        # Hidden layer over a pair of pooled depth features (input width is
        # twice the pooled dimension — presumably subject and object; confirm
        # against forward()).
        pair_width = self.depth_pooling_dim * 2
        hidden_fc = xavier_init(nn.Linear(pair_width, self.FC_SIZE_DEPTH))
        self.depth_rel_hlayer = nn.Sequential(
            hidden_fc,
            nn.ReLU(inplace=True),
            nn.Dropout(0.6),
        )

        # Output layer: one score per relation class.
        scores_fc = nn.Linear(self.FC_SIZE_DEPTH, self.num_rels, bias=True)
        self.depth_rel_out = xavier_init(scores_fc)

        # With pretrained weights, keep the depth backbone fixed.
        if pretrained_depth:
            self.freeze_module(self.depth_backbone)
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 require_overlap_det=True,
                 depth_model=None,
                 pretrained_depth=False,
                 **kwargs):
        """
        Build a relation classifier driven by depth features of box unions.

        A depth CNN backbone is created and a union-of-boxes pooling module
        is placed on top of it. The backbone's classifier head extracts
        features from the pooled union regions, and a single FC layer maps
        them to relation-class scores.

        :param classes: object classes
        :param rel_classes: relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param require_overlap_det: whether two objects must intersect
        :param depth_model: provided architecture for depth feature extraction;
            must be a member of DEPTH_MODELS
        :param pretrained_depth: whether the depth feature extractor should be
            initialized with ImageNet weights; when True the depth backbone is
            frozen
        :param kwargs: accepted for interface compatibility; not used here
        """
        RelModelBase.__init__(self, classes, rel_classes, mode, num_gpus,
                              require_overlap_det)

        # Reject unknown depth architectures before looking them up below.
        assert depth_model in DEPTH_MODELS
        self.depth_model = depth_model
        self.pretrained_depth = pretrained_depth
        self.depth_pooling_dim = DEPTH_DIMS[depth_model]
        self.depth_channels = DEPTH_CHANNELS[depth_model]
        self.pooling_size = 7
        # An empty nn.Module stands in for the object detector: none is built
        # by this model.
        self.detector = nn.Module()

        # NOTE: sub-modules are created in a fixed order on purpose; weight
        # initialisation consumes the RNG, and registration order determines
        # parameter order.
        self.depth_backbone = DepthCNN(depth_model=depth_model,
                                       pretrained=pretrained_depth)

        # Union-of-boxes pooler over the depth feature map. Its channel count
        # comes from DEPTH_CHANNELS; stride=16 is presumably the backbone's
        # downsampling factor (confirm).
        self.depth_union_boxes = UnionBoxesAndFeats(
            pooling_size=self.pooling_size,
            stride=16,
            dim=self.depth_channels,
        )

        # Relation head reusing the backbone's classifier on the pooled union
        # features.
        self.depth_rel_head_union = self.depth_backbone.get_classifier()

        # Output layer: one score per relation class.
        scores_fc = nn.Linear(self.depth_pooling_dim, self.num_rels, bias=True)
        self.depth_rel_out = xavier_init(scores_fc)

        # With pretrained weights, keep the depth backbone fixed.
        if pretrained_depth:
            self.freeze_module(self.depth_backbone)
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 use_vision=False,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=4096,
                 use_resnet=False,
                 thresh=0.01,
                 use_proposals=False,
                 use_bias=True,
                 limit_vision=True,
                 depth_model=None,
                 pretrained_depth=False,
                 active_features=None,
                 frozen_features=None,
                 use_embed=False,
                 **kwargs):
        """
        Build the multi-feature (fusion) relation model.

        Each active feature group (visual, location, class, depth) gets its
        own hidden layer, and the output widths are summed into the input
        width of a fusion hidden layer. The features are presumably
        concatenated in forward(); confirm there. A final FC layer scores
        every relation class.

        :param classes: object classes
        :param rel_classes: relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param use_vision: enable the contribution of union of bounding boxes
        :param require_overlap_det: whether two objects must intersect
        :param embed_dim: word2vec embeddings dimension
        :param hidden_dim: dimension of the fusion hidden layer
        :param use_resnet: use resnet as faster-rcnn's backbone (currently
            unsupported; passing True raises ValueError)
        :param thresh: faster-rcnn related threshold (Threshold for calling it a good box)
        :param use_proposals: whether to use region proposal candidates
        :param use_bias: enable frequency bias
        :param limit_vision: use truncated version of UoBB features
        :param depth_model: provided architecture for depth feature extraction;
            checked against DEPTH_MODELS only when depth features are active
        :param pretrained_depth: whether the depth feature extractor should be initialized with ImageNet weights
        :param active_features: what set of features should be enabled (e.g. 'vdl' : visual, depth, and location features)
        :param frozen_features: what set of features should be frozen (e.g. 'd' : depth)
        :param use_embed: use word2vec embeddings
        :param kwargs: accepted for interface compatibility; not used here
        :raises ValueError: if use_resnet is True
        """
        RelModelBase.__init__(self, classes, rel_classes, mode, num_gpus,
                              require_overlap_det, active_features,
                              frozen_features)
        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.obj_dim = 2048 if use_resnet else 4096
        self.use_vision = use_vision
        self.use_bias = use_bias
        self.limit_vision = limit_vision

        # -- Store depth related parameters
        # The depth architecture only has to be valid when depth features are
        # active. An unconditional check rejected non-depth configurations
        # that keep the default depth_model=None (unless None is listed in
        # DEPTH_MODELS).
        if self.has_depth:
            assert depth_model in DEPTH_MODELS, \
                'Unsupported depth model: {!r}'.format(depth_model)
        self.depth_model = depth_model
        self.pretrained_depth = pretrained_depth
        self.depth_pooling_dim = (DEPTH_DIMS[depth_model]
                                  if depth_model in DEPTH_MODELS else None)
        self.use_embed = use_embed
        # Empty placeholder; replaced by an ObjectDetector when visual
        # features are active.
        self.detector = nn.Module()
        # Running total of the hidden-layer widths of all active feature
        # groups; used as the fusion layer's input width.
        features_size = 0

        # -- Check whether ResNet is selected as faster-rcnn's backbone
        if use_resnet:
            raise ValueError(
                "The current model does not support ResNet as the Faster-RCNN's backbone."
            )

        # *** DIFFERENT COMPONENTS OF THE PROPOSED ARCHITECTURE ***
        # This is where the components of the relation detection architecture
        # are defined. For RGB images there are class probability
        # distribution features, visual features, and location features.
        # When depth images are also considered, depth features are added.
        # NOTE: sub-modules are created in a fixed order; weight
        # initialisation consumes the RNG, and registration order determines
        # parameter order.

        # -- Visual features
        if self.has_visual:
            # -- Define faster R-CNN network and its related feature extractors
            self.detector = ObjectDetector(
                classes=classes,
                mode=('proposals' if use_proposals else 'refinerels')
                if mode == 'sgdet' else 'gtbox',
                use_resnet=use_resnet,
                thresh=thresh,
                max_per_img=64,
            )
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

            # -- Define union features
            if self.use_vision:
                # -- UoBB pooling module
                self.union_boxes = UnionBoxesAndFeats(
                    pooling_size=self.pooling_size,
                    stride=16,
                    dim=1024 if use_resnet else 512)

                # -- UoBB feature extractor; an extra projection is appended
                # when the fusion width differs from VGG's 4096.
                roi_fmap = [
                    Flattener(),
                    load_vgg(use_dropout=False,
                             use_relu=False,
                             use_linear=self.hidden_dim == 4096,
                             pretrained=False).classifier,
                ]
                if self.hidden_dim != 4096:
                    roi_fmap.append(nn.Linear(4096, self.hidden_dim))
                self.roi_fmap = nn.Sequential(*roi_fmap)

            # -- Define visual features hidden layer (input width is twice
            # obj_dim, presumably one feature per object in the pair)
            self.visual_hlayer = nn.Sequential(
                xavier_init(nn.Linear(self.obj_dim * 2, self.FC_SIZE_VISUAL)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.8),
            )
            self.visual_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_VISUAL

        # -- Location features
        if self.has_loc:
            # -- Define location features hidden layer
            self.location_hlayer = nn.Sequential(
                xavier_init(nn.Linear(self.LOC_INPUT_SIZE, self.FC_SIZE_LOC)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1),
            )
            self.location_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_LOC

        # -- Class features
        if self.has_class:
            if self.use_embed:
                # -- Define class embeddings, initialised from word vectors
                embed_vecs = obj_edge_vectors(self.classes,
                                              wv_dim=self.embed_dim)
                self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim)
                self.obj_embed.weight.data = embed_vecs.clone()

            classme_input_dim = self.embed_dim if self.use_embed else self.num_classes
            # -- Define Class features hidden layer
            self.classme_hlayer = nn.Sequential(
                xavier_init(
                    nn.Linear(classme_input_dim * 2, self.FC_SIZE_CLASS)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1),
            )
            self.classme_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_CLASS

        # -- Depth features
        if self.has_depth:
            # -- Initialize depth backbone
            self.depth_backbone = DepthCNN(depth_model=self.depth_model,
                                           pretrained=self.pretrained_depth)

            # -- Create a relation head which is used to carry on the feature extraction
            # from RoIs of depth features
            self.depth_rel_head = self.depth_backbone.get_classifier()

            # -- Define depth features hidden layer
            self.depth_rel_hlayer = nn.Sequential(
                xavier_init(
                    nn.Linear(self.depth_pooling_dim * 2, self.FC_SIZE_DEPTH)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.6),
            )
            self.depth_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_DEPTH

        # -- Initialize frequency bias if needed
        if self.use_bias:
            self.freq_bias = FrequencyBias()

        # -- *** Fusion layer *** --
        # -- A hidden layer over the combined features of all active groups
        self.fusion_hlayer = nn.Sequential(
            xavier_init(nn.Linear(features_size, self.hidden_dim)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )

        # -- Final FC layer which predicts the relations
        self.rel_out = xavier_init(
            nn.Linear(self.hidden_dim, self.num_rels, bias=True))

        # -- Freeze the user specified features.
        # Each group is frozen only if it was built above. Freezing an
        # inactive group would reference sub-modules this constructor never
        # created.
        if self.has_visual and self.frz_visual:
            self.freeze_module(self.detector)
            self.freeze_module(self.roi_fmap_obj)
            self.freeze_module(self.visual_hlayer)
            if self.use_vision:
                self.freeze_module(self.roi_fmap)
                self.freeze_module(self.union_boxes.conv)

        if self.has_class and self.frz_class:
            self.freeze_module(self.classme_hlayer)

        if self.has_loc and self.frz_loc:
            self.freeze_module(self.location_hlayer)

        if self.has_depth and self.frz_depth:
            self.freeze_module(self.depth_backbone)
            self.freeze_module(self.depth_rel_head)
            self.freeze_module(self.depth_rel_hlayer)