Esempio n. 1
0
    def __init__(self, num_answers, fn_type="softmax"):
        """Build the LXRT encoder, the type predictor and the answer heads.

        Creates a 4-way type predictor, four identically shaped per-type
        feed-forward branches (yes/no, number, other, color) and two stacked
        answer heads that end in ``num_answers`` logits.

        Args:
            num_answers: Output width of the final ``logit_fc`` layer.
            fn_type: Chooses the module stored in ``self.fn``: ``"tanh"`` or
                ``"softmax"``; any other value falls back to sigmoid.
        """
        super().__init__()

        # Build LXRT encoder
        self.lxrt_encoder = LXRTEncoder(args, max_seq_length=MAX_VQA_LENGTH)

        hid_dim = self.lxrt_encoder.dim
        print("Size of Hidden Dimension:", hid_dim)
        fc_dim = int(hid_dim)
        print("Size of Hidden Dimension:", fc_dim)

        def _branch():
            # One per-type branch: hid_dim -> 2 * hid_dim -> fc_dim, with GeLU
            # followed by layer norm after each linear layer.
            return nn.Sequential(
                nn.Linear(hid_dim, hid_dim * 2),
                GeLU(),
                BertLayerNorm(hid_dim * 2, eps=1e-12),
                nn.Linear(hid_dim * 2, fc_dim),
                GeLU(),
                BertLayerNorm(fc_dim, eps=1e-12),
            )

        # Type Predictor (4 output logits)
        self.type_fc = nn.Sequential(
            nn.Linear(hid_dim, hid_dim * 2),
            GeLU(),
            BertLayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, 4),
        )

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        # Explicit dim: nn.Softmax() without `dim` relies on a deprecated
        # implicit choice and emits a warning on every call. dim=-1 equals the
        # implicit choice (dim=1) for 2-D inputs.
        # NOTE(review): assumes self.fn is applied to a 2-D (batch, classes)
        # tensor -- confirm against forward().
        self.softmax = nn.Softmax(dim=-1)

        if fn_type == "tanh":
            self.fn = self.tanh
            print("FN: TANH")
        elif fn_type == "softmax":
            self.fn = self.softmax
            print("FN: SOFTMAX")
        else:
            self.fn = self.sigmoid
            print("FN: SIGMOID")

        # Per-type feed-forward branches. Construction order is kept as
        # YESNO, NUMBER, OTHER, COLOR so module registration order is unchanged.
        self.yesno_fc = _branch()
        self.number_fc = _branch()
        self.other_fc = _branch()
        self.color_fc = _branch()

        # Answering Heads: first stage maps a 5 * fc_dim-wide input to hid_dim.
        self.logit_fc1 = nn.Sequential(
            nn.Linear(5 * fc_dim, hid_dim * 2),
            GeLU(),
            BertLayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, hid_dim),
        )

        # Answering Heads: second stage maps hid_dim to num_answers logits.
        self.logit_fc = nn.Sequential(
            nn.Linear(hid_dim, hid_dim * 2),
            GeLU(),
            BertLayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, num_answers),
        )

        # Only logit_fc is passed through the encoder model's BERT weight
        # init; every other head keeps its constructor-time initialization.
        self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)
    def __init__(self, num_answers, fn_type="softmax"):
        """Build the LXRT encoder, the yes/no sub-type predictor and the heads.

        Creates a 4-way AND/OR/NOT/NONE predictor, four identically shaped
        feed-forward branches (one per sub-type) and two stacked answer heads
        that end in ``num_answers`` logits.

        Args:
            num_answers: Output width of the final ``logit_fc`` layer.
            fn_type: Chooses the module stored in ``self.fn``: ``"tanh"`` or
                ``"softmax"``; any other value falls back to sigmoid.
        """
        super().__init__()

        # Build LXRT encoder
        self.lxrt_encoder = LXRTEncoder(args, max_seq_length=MAX_VQA_LENGTH)

        hid_dim = self.lxrt_encoder.dim
        print("Size of Hidden Dimension:", hid_dim)
        fc_dim = int(hid_dim)
        print("Size of Hidden Dimension:", fc_dim)

        def _branch():
            # One per-sub-type branch: hid_dim -> 2 * hid_dim -> fc_dim, with
            # GeLU followed by layer norm after each linear layer.
            return nn.Sequential(
                nn.Linear(hid_dim, hid_dim * 2),
                GeLU(),
                BertLayerNorm(hid_dim * 2, eps=1e-12),
                nn.Linear(hid_dim * 2, fc_dim),
                GeLU(),
                BertLayerNorm(fc_dim, eps=1e-12),
            )

        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        # Explicit dim: nn.Softmax() without `dim` relies on a deprecated
        # implicit choice and emits a warning on every call. dim=-1 equals the
        # implicit choice (dim=1) for 2-D inputs.
        # NOTE(review): assumes self.fn is applied to a 2-D (batch, classes)
        # tensor -- confirm against forward().
        self.softmax = nn.Softmax(dim=-1)

        if fn_type == "tanh":
            self.fn = self.tanh
            print("FN: TANH")
        elif fn_type == "softmax":
            self.fn = self.softmax
            print("FN: SOFTMAX")
        else:
            self.fn = self.sigmoid
            print("FN: SIGMOID")

        # YN:AND/OR/NOT/NONE Type Predictor (4 output logits)
        self.yn_fc = nn.Sequential(
            nn.Linear(hid_dim, hid_dim * 2),
            GeLU(),
            BertLayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, 4),
        )

        # Feed-forward branches. Construction order is kept as AND, OR, NOT,
        # NONE so module registration order is unchanged.
        self.and_fc = _branch()
        self.or_fc = _branch()
        self.not_fc = _branch()
        self.none_fc = _branch()

        # Answering Heads: first stage expects a 6 * fc_dim-wide input.
        # NOTE(review): only four fc_dim-wide branches are built here; confirm
        # in forward() what makes up the remaining 2 * fc_dim of the input.
        self.logit_fc1 = nn.Sequential(
            nn.Linear(6 * fc_dim, hid_dim * 2),
            GeLU(),
            BertLayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, hid_dim),
        )

        # Answering Heads: second stage maps hid_dim to num_answers logits.
        self.logit_fc = nn.Sequential(
            nn.Linear(hid_dim, hid_dim * 2),
            GeLU(),
            BertLayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, num_answers),
        )

        # Only logit_fc is passed through the encoder model's BERT weight
        # init; every other head keeps its constructor-time initialization.
        self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)