Example #1
0
    def __init__(self, feature_dim, output_dims, downsample_rate, args=None):
        """Build the scene-graph feature extractor.

        Args:
            feature_dim: number of channels of the backbone feature map.
            output_dims: per-level output sizes; index 1 is the object
                feature size, index 2 the relation feature size.
            downsample_rate: backbone stride; RoI coordinates are scaled by
                ``1 / downsample_rate``.
            args: optional options object; when given and ``rel_box_flag`` is
                set, an extra box-fusion linear layer is created.
        """
        super().__init__()
        self.feature_dim = feature_dim
        self.output_dims = output_dims
        self.downsample_rate = downsample_rate
        self.args = args
        # Guard the default args=None: the original dereferenced
        # self.args.rel_box_flag unconditionally, crashing when args was
        # omitted despite the declared default.
        if self.args is not None and self.args.rel_box_flag:
            self.col_fuse = nn.Linear(128 * 4 * 4, output_dims[1])

        # Decide the pooling window once, before building the RoI pools, so
        # each pool is constructed a single time (the original built them at
        # size 7 and then rebuilt all three at size 32 under DEBUG).
        self.pool_size = 32 if DEBUG else 7
        self.object_roi_pool = jacnn.PrRoIPool2D(self.pool_size,
                                                 self.pool_size,
                                                 1.0 / downsample_rate)
        self.context_roi_pool = jacnn.PrRoIPool2D(self.pool_size,
                                                  self.pool_size,
                                                  1.0 / downsample_rate)
        self.relation_roi_pool = jacnn.PrRoIPool2D(self.pool_size,
                                                   self.pool_size,
                                                   1.0 / downsample_rate)

        if not DEBUG:
            # 1x1 convs extracting context / relation features from the map.
            self.context_feature_extract = nn.Conv2d(feature_dim, feature_dim,
                                                     1)
            self.relation_feature_extract = nn.Conv2d(feature_dim,
                                                      feature_dim // 2 * 3, 1)

            # Fuse object features with pooled context; fuse relation
            # features with both endpoint object features.
            self.object_feature_fuse = nn.Conv2d(feature_dim * 2,
                                                 output_dims[1], 1)
            self.relation_feature_fuse = nn.Conv2d(
                feature_dim // 2 * 3 + output_dims[1] * 2, output_dims[2], 1)

            # Flattened pooled grids -> fixed-size feature vectors.
            self.object_feature_fc = nn.Sequential(
                nn.ReLU(True),
                nn.Linear(output_dims[1] * self.pool_size**2, output_dims[1]))
            self.relation_feature_fc = nn.Sequential(
                nn.ReLU(True),
                nn.Linear(output_dims[2] * self.pool_size**2, output_dims[2]))

            self.reset_parameters()
        else:
            # DEBUG mode: replace the learned extractors / fusers with cheap
            # channel-replication and identity stand-ins.
            def gen_replicate(n):
                # Return a function concatenating `n` copies of its input
                # along the channel dimension.
                def rep(x):
                    return torch.cat([x for _ in range(n)], dim=1)

                return rep

            self.context_feature_extract = gen_replicate(2)
            self.relation_feature_extract = gen_replicate(3)
            self.object_feature_fuse = jacnn.Identity()
            self.relation_feature_fuse = jacnn.Identity()
Example #2
0
    def __init__(self, vocab, configs):
        """Assemble the model: ResNet backbone, scene graph, differentiable
        reasoning module, and the scene-parsing / QA losses."""
        super().__init__()
        self.vocab = vocab

        # ResNet-34 backbone with layer4 disabled (no global average pooling,
        # no classifier head), keeping features at the layer3 resolution.
        import jactorch.models.vision.resnet as resnet
        backbone = resnet.resnet34(pretrained=True,
                                   incl_gap=False,
                                   num_classes=None)
        backbone.layer4 = jacnn.Identity()
        self.resnet = backbone

        # Scene graph on top of the backbone features.
        # number of channels = 256; downsample rate = 16.
        import nscl.nn.scene_graph.scene_graph as sng
        self.scene_graph = sng.SceneGraph(256, configs.model.sg_dims, 16)

        # Differentiable reasoning over the VSE concept embeddings.
        import nscl.nn.reasoning_v1.quasi_symbolic as qs
        concepts = self._make_vse_concepts(configs.model.vse_large_scale,
                                           configs.model.vse_known_belong)
        self.reasoning = qs.DifferentiableReasoning(
            concepts, self.scene_graph.output_dims,
            configs.model.vse_hidden_dims)

        # Training losses.
        import nscl.nn.reasoning_v1.losses as vqa_losses
        self.scene_loss = vqa_losses.SceneParsingLoss(
            gdef.all_concepts,
            add_supervision=configs.train.scene_add_supervision)
        self.qa_loss = vqa_losses.QALoss(
            add_supervision=configs.train.qa_add_supervision)
Example #3
0
    def build(self):
        """Construct the sub-modules: optional vision backbone, classifier
        MLP, word embedding, question GRU, and the loss functions."""
        # Classifier input: bidirectional GRU state (128 * 2), plus the
        # 256-dim visual feature when vision is enabled.
        if self.use_vision:
            import jactorch.models.vision.resnet as resnet
            self.resnet = resnet.resnet34(pretrained=True,
                                          incl_gap=False,
                                          num_classes=None)
            # Keep the backbone at the layer3 resolution.
            self.resnet.layer4 = jacnn.Identity()
            self.mlp = jacnn.MLPLayer(256 + 128 * 2, len(self.tools.answers),
                                      [512])
        else:
            self.mlp = jacnn.MLPLayer(128 * 2, len(self.tools.answers), [256])

        # Word embedding; the <NULL> token serves as padding.
        self.embedding = nn.Embedding(self.num_vocab,
                                      self.dim,
                                      padding_idx=self.tools.words['<NULL>'])

        # Single-layer bidirectional GRU over the embedded question.
        self.gru = jacnn.GRULayer(self.dim,
                                  128,
                                  1,
                                  bidirectional=True,
                                  batch_first=True,
                                  dropout=0.1)

        self.loss_fn = F.nll_loss

        # Optional language-model head decoding GRU states back to the vocab.
        if self.use_lm:
            self.gru_dropout = nn.Dropout(0.1)
            self.decode = nn.Linear(128 * 2, self.num_vocab)
            self.decode.bias.data.zero_()
            self.decode_loss = jacnn.CrossEntropyLoss(average='none')
Example #4
0
    def __init__(self, vocab, configs, args=None):
        """Build the CLEVRER model: ResNet backbone, scene graph,
        differentiable reasoning, and the scene / QA losses.

        Args:
            vocab: dataset vocabulary.
            configs: model and training configuration namespace.
            args: runtime options controlling how the relation feature
                dimension (``output_dims[2]``) is assembled.
        """
        super().__init__()
        self.vocab = vocab
        self.args = args

        # ResNet-34 backbone truncated after layer3 (layer4 replaced by an
        # identity), so features stay at 256 channels.
        import jactorch.models.vision.resnet as resnet
        self.resnet = resnet.resnet34(pretrained=True,
                                      incl_gap=False,
                                      num_classes=None)
        self.resnet.layer4 = jacnn.Identity()

        # Scene graph over the backbone features.
        # number of channels = 256; downsample rate = 16.
        import clevrer.models.scene_graph as sng
        self.scene_graph = sng.SceneGraph(256,
                                          configs.model.sg_dims,
                                          16,
                                          args=configs)

        import clevrer.models.quasi_symbolic as qs
        # `dims` aliases the scene graph's output_dims list, so these updates
        # mutate it in place (as the original did via repeated indexing).
        dims = self.scene_graph.output_dims
        if configs.rel_box_flag:
            dims[2] = dims[2] * 2
        if configs.dynamic_ftr_flag and (
                not self.args.box_only_for_collision_flag):
            dims[2] = dims[2] + dims[3] * 4
        elif configs.dynamic_ftr_flag and self.args.box_only_for_collision_flag:
            dims[2] = dims[3] * 4

        if self.args.box_iou_for_collision_flag:
            box_dim = 4
            dims[2] += int(dims[3] / box_dim)

        self.reasoning = qs.DifferentiableReasoning(
            self._make_vse_concepts(configs.model.vse_large_scale,
                                    configs.model.vse_known_belong),
            dims,
            configs.model.vse_hidden_dims,
            args=self.args)

        # Losses.
        import clevrer.losses as vqa_losses
        self.scene_loss = vqa_losses.SceneParsingLoss(
            gdef.all_concepts_clevrer,
            add_supervision=configs.train.scene_add_supervision,
            args=self.args)
        self.qa_loss = vqa_losses.QALoss(
            add_supervision=configs.train.qa_add_supervision)
Example #5
0
    def __init__(self, configs, args=None):
        """Build the CLEVRER v2 model: ResNet backbone, scene graph,
        differentiable reasoning, and the scene / QA losses.

        Args:
            configs: model/training configuration namespace; mutated here to
                carry ``colli_ftr_type`` and (via the scene graph's
                output_dims) the adjusted relation feature size.
            args: runtime options. NOTE(review): despite the ``None``
                default, ``args`` is dereferenced unconditionally below, so
                it is effectively required — confirm with callers.
        """
        super().__init__()
        self.args = args
        # Propagate the collision feature type into the scene-graph configs.
        configs.colli_ftr_type = args.colli_ftr_type

        # ResNet-34 backbone truncated after layer3 (256 channels).
        import jactorch.models.vision.resnet as resnet
        self.resnet = resnet.resnet34(pretrained=True,
                                      incl_gap=False,
                                      num_classes=None)
        self.resnet.layer4 = jacnn.Identity()

        import clevrer.models.scene_graph as sng
        # number of channels = 256; downsample rate = 16.
        self.scene_graph = sng.SceneGraph(256,
                                          configs.model.sg_dims,
                                          16,
                                          args=configs)

        import clevrer.models.quasi_symbolic_v2 as qs
        # Frames per collision segment: the box trajectory feature holds
        # time_step = ftr_dim / box_dim steps, split evenly into
        # smp_coll_frm_num segments (the remainder `offset` is dropped).
        ftr_dim = self.scene_graph.output_dims[3]
        box_dim = 4
        time_step = int(ftr_dim / box_dim)
        offset = time_step % self.args.smp_coll_frm_num
        seg_frm_num = int((time_step - offset) / self.args.smp_coll_frm_num)

        # Adjust the relation feature size (output_dims[2]) per the flags.
        if configs.rel_box_flag:
            self.scene_graph.output_dims[
                2] = self.scene_graph.output_dims[2] * 2
        if configs.dynamic_ftr_flag and (
                not self.args.box_only_for_collision_flag):
            self.scene_graph.output_dims[2] = self.scene_graph.output_dims[
                2] + seg_frm_num * 4 * box_dim
        elif configs.dynamic_ftr_flag and self.args.box_only_for_collision_flag:
            self.scene_graph.output_dims[2] = seg_frm_num * 4 * box_dim

        if self.args.box_iou_for_collision_flag:
            # (Removed a redundant `box_dim = 4` dead store here; box_dim is
            # already 4 and is not read again below.)
            self.scene_graph.output_dims[2] += seg_frm_num

        self.reasoning = qs.DifferentiableReasoning(
            self._make_vse_concepts(configs.model.vse_known_belong),
            self.scene_graph.output_dims,
            configs.model.vse_hidden_dims,
            args=self.args,
            seg_frm_num=seg_frm_num)

        # Losses.
        import clevrer.losses_v2 as vqa_losses
        self.scene_loss = vqa_losses.SceneParsingLoss(
            gdef.all_concepts_clevrer,
            add_supervision=configs.train.scene_add_supervision,
            args=self.args)
        self.qa_loss = vqa_losses.QALoss(
            add_supervision=configs.train.qa_add_supervision, args=self.args)
    def __init__(self,
                 feature_dim,
                 output_dims,
                 downsample_rate,
                 object_supervision=False,
                 concatenative_pair_representation=True):
        """Build the scene-graph module.

        Two operating modes:
          * ``object_supervision=True``: RoI-pool object / context / relation
            features from the backbone map using provided boxes.
          * ``object_supervision=False``: discover up to a fixed number of
            objects with learned queries over coordinate-augmented features.

        Args:
            feature_dim: channels of the backbone feature map.
            output_dims: per-level output sizes; index 1 is the object
                feature size, index 2 the relation feature size.
            downsample_rate: backbone stride; RoI coordinates are scaled by
                its inverse.
            object_supervision: whether ground-truth boxes are available.
            concatenative_pair_representation: stored flag — presumably
                controls how object pairs are represented downstream
                (not used in this constructor).
        """
        super().__init__()
        self.pool_size = 7
        self.feature_dim = feature_dim
        self.output_dims = output_dims
        self.downsample_rate = downsample_rate

        self.object_supervision = object_supervision
        self.concatenative_pair_representation = concatenative_pair_representation

        if self.object_supervision:
            # DEBUG uses a larger pooling window; fix pool_size before
            # constructing the RoI pools so each is built exactly once (the
            # original built all three at size 7 and then rebuilt them at 32
            # in the DEBUG branch).
            if DEBUG:
                self.pool_size = 32
            self.object_roi_pool = jacnn.PrRoIPool2D(self.pool_size,
                                                     self.pool_size,
                                                     1.0 / downsample_rate)
            self.context_roi_pool = jacnn.PrRoIPool2D(self.pool_size,
                                                      self.pool_size,
                                                      1.0 / downsample_rate)
            self.relation_roi_pool = jacnn.PrRoIPool2D(self.pool_size,
                                                       self.pool_size,
                                                       1.0 / downsample_rate)

            if not DEBUG:
                # 1x1 convs extracting context / relation features.
                self.context_feature_extract = nn.Conv2d(
                    feature_dim, feature_dim, 1)
                self.relation_feature_extract = nn.Conv2d(
                    feature_dim, feature_dim // 2 * 3, 1)

                # Fuse object features with context, and relation features
                # with both endpoint object features.
                self.object_feature_fuse = nn.Conv2d(feature_dim * 2,
                                                     output_dims[1], 1)
                self.relation_feature_fuse = nn.Conv2d(
                    feature_dim // 2 * 3 + output_dims[1] * 2, output_dims[2],
                    1)

                # Flattened pooled grids -> fixed-size feature vectors.
                self.object_feature_fc = nn.Sequential(
                    nn.ReLU(True),
                    nn.Linear(output_dims[1] * self.pool_size**2,
                              output_dims[1]))
                self.relation_feature_fc = nn.Sequential(
                    nn.ReLU(True),
                    nn.Linear(output_dims[2] * self.pool_size**2,
                              output_dims[2]))

                # Endpoint projections for pair representations.
                self.obj1_linear = nn.Linear(output_dims[1], output_dims[1])
                self.obj2_linear = nn.Linear(output_dims[1], output_dims[1])

                self.reset_parameters()
            else:
                # DEBUG mode: replace the learned extractors / fusers with
                # channel-replication and identity stand-ins.
                def gen_replicate(n):
                    # Return a function concatenating `n` copies of its
                    # input along the channel dimension.
                    def rep(x):
                        return torch.cat([x for _ in range(n)], dim=1)

                    return rep

                self.context_feature_extract = gen_replicate(2)
                self.relation_feature_extract = gen_replicate(3)
                self.object_feature_fuse = jacnn.Identity()
                self.relation_feature_fuse = jacnn.Identity()

        else:
            # Unsupervised object discovery: attend a fixed set of learned
            # query vectors over coordinate-augmented features.
            self.num_objects_upperbound = 11
            self.temperature = 6
            self.object_coord_fuse = nn.Sequential(
                nn.Conv2d(feature_dim + 2, feature_dim, kernel_size=1),
                nn.ReLU(True))
            self.query = nn.Parameter(
                torch.randn(self.num_objects_upperbound, feature_dim))
            self.object_features_layer = nn.Sequential(
                nn.Linear(feature_dim, output_dims[1]), nn.ReLU(True))
            self.obj1_linear = nn.Linear(output_dims[1], output_dims[1])
            self.obj2_linear = nn.Linear(output_dims[1], output_dims[1])
            self.reset_parameters()