def build_ban(dataset, num_hid, op='', gamma=4, task='vqa', use_counter=True):
    """Build a BAN model that encodes questions with a frozen ALBERT.

    Returns a BanModel for task='vqa', a BanModel_flickr for task='flickr',
    and (implicitly) None for any other task value.
    """
    # Pretrained ALBERT replaces the original word/question embedding modules.
    w_emb = AlbertTokenizer.from_pretrained('albert-large-v2')
    q_emb = AlbertModel.from_pretrained('albert-large-v2')
    # Freeze the language model; keep the frozen parameters in a set so the
    # model can tell them apart later.
    params_set = set()
    for p in q_emb.parameters():
        params_set.add(p)
        p.requires_grad = False
    v_att = BiAttention(dataset.v_dim, num_hid, num_hid, gamma)
    if task == 'vqa':
        objects = 10  # minimum number of boxes
        b_net, q_prj, c_prj = [], [], []
        for _ in range(gamma):
            b_net.append(BCNet(dataset.v_dim, num_hid, num_hid, None, k=1))
            q_prj.append(FCNet([num_hid, num_hid], '', .2))
            c_prj.append(FCNet([objects + 1, num_hid], 'ReLU', .0))
        classifier = SimpleClassifier(num_hid, num_hid * 2,
                                      dataset.num_ans_candidates, .5)
        counter = Counter(objects) if use_counter else None
        return BanModel(dataset, params_set, w_emb, q_emb, v_att, b_net, q_prj,
                        c_prj, classifier, counter, op, gamma)
    elif task == 'flickr':
        return BanModel_flickr(w_emb, q_emb, v_att, op, gamma)
# Example #2
def build_ban(num_token,
              v_dim,
              num_hid,
              num_ans,
              op='',
              gamma=4,
              reasoning=False):
    """Build a BAN (or its reasoning variant) from raw dimensions.

    The reasoning variant uses a single attention glimpse and the
    BanModel_Reasoning wrapper; otherwise `gamma` glimpses and BanModel.
    """
    w_emb = WordEmbedding(num_token, 300, .0, op)
    rnn_in_dim = 600 if 'c' in op else 300
    q_emb = QuestionEmbedding(rnn_in_dim, num_hid, 1, False, .0)
    glimpses = 1 if reasoning else gamma
    v_att = BiAttention(v_dim, num_hid, num_hid, glimpses)

    # constructing the model
    objects = 36  # minimum number of boxes, originally 10
    b_net, q_prj, c_prj = [], [], []
    for _ in range(gamma):
        b_net.append(BCNet(v_dim, num_hid, num_hid, None, k=1))
        q_prj.append(FCNet([num_hid, num_hid], '', .2))
        c_prj.append(FCNet([objects + 1, num_hid], 'ReLU', .0))
    classifier = SimpleClassifier(num_hid, num_hid * 2, num_ans, .5)
    counter = Counter(objects)

    model_cls = BanModel_Reasoning if reasoning else BanModel
    return model_cls(w_emb, q_emb, v_att, b_net, q_prj, c_prj, classifier,
                     counter, op, gamma, num_hid)
def build_ban(dataset, num_hid, op='', gamma=4, task='vqa'):
    """Build a BAN that adds NewAttention plus separate q/v projection nets.

    Returns a BanModel for task='vqa', a BanModel_flickr for task='flickr',
    and (implicitly) None otherwise.
    """
    w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, .0, op)
    q_emb = QuestionEmbedding(600 if 'c' in op else 300, num_hid, 1, False, .0)
    v_att = NewAttention(dataset.v_dim, num_hid, num_hid, dropout=0.2)
    q_net = FCNet([q_emb.num_hid, num_hid])
    v_net = FCNet([dataset.v_dim, num_hid])

    # Scalar question attention over hidden states.
    q_att = weight_norm(nn.Linear(num_hid, 1), dim=None)
    if task == 'vqa':
        objects = 10  # minimum number of boxes
        b_net, q_prj, c_prj = [], [], []
        for _ in range(gamma):
            b_net.append(BCNet(dataset.v_dim, num_hid, num_hid, None, k=1))
            q_prj.append(FCNet([num_hid, num_hid], '', .2))
            c_prj.append(FCNet([objects + 1, num_hid], 'ReLU', .0))
        # Classifier head with a fixed 3-way output.
        classifier = SimpleClassifier(num_hid, num_hid * 2, 3, .5)
        counter = Counter(objects)
        return BanModel(dataset, w_emb, q_emb, v_att, q_att, b_net, q_prj,
                        c_prj, q_net, v_net, classifier, counter, op, gamma)
    elif task == 'flickr':
        return BanModel_flickr(w_emb, q_emb, v_att, op, gamma)
# Example #4
    def __init__(self, x_dim, y_dim, z_dim, glimpse, dropout=[.2, .5]):
        """Set up bi-attention logits for both directions (visual and question).

        NOTE(review): `dropout` is a mutable default list shared across calls;
        safe only if BCNet/BCNet_q never mutate it — confirm.
        """
        super(BiAttention_both, self).__init__()

        self.glimpse = glimpse
        # Weight-normalized bilinear logits, visual direction.
        self.logits_v = weight_norm(BCNet(x_dim, y_dim, z_dim, glimpse, dropout=dropout, k=3), \
            name='h_mat', dim=None)
        # NOTE(review): `v_num` is not defined in this method's scope or its
        # parameters — presumably a module-level constant; verify it exists
        # at import time, otherwise this raises NameError.
        self.logits_q = weight_norm(BCNet_q(v_num, x_dim, y_dim, z_dim, glimpse, dropout=dropout, k=3), \
            name='h_mat', dim=None)
# Example #5
    def __init__(self, x_dim, y_dim, z_dim, glimpse, dropout=None):
        """Set up weight-normalized bilinear attention logits.

        Args:
            x_dim: visual feature dimension fed to BCNet.
            y_dim: question feature dimension fed to BCNet.
            z_dim: joint embedding dimension inside BCNet.
            glimpse: number of attention glimpses.
            dropout: two dropout rates for BCNet; defaults to [0.2, 0.5].

        Fix: the original used a mutable default argument ([0.2, 0.5]) that
        is shared across every call and handed to BCNet; a None sentinel
        keeps the same effective default without the shared-state hazard.
        """
        super(BiAttention, self).__init__()
        if dropout is None:
            dropout = [0.2, 0.5]

        self.glimpse = glimpse
        self.logits = weight_norm(
            BCNet(x_dim, y_dim, z_dim, glimpse, dropout=dropout, k=3),
            name="h_mat",
            dim=None,
        )
# Example #6
def build_BAN(dataset, args, priotize_using_counter=False):
    """Assemble a BAN VQA model with optional MAML / auto-encoder visual encoders.

    Args:
        dataset: dataset object exposing dictionary.ntoken, v_dim, and
            num_ans_candidates.
        args: config namespace (op, num_hid, rnn, gamma, maml, autoencoder,
            RAD_dir, model paths, use_counter, ...).
        priotize_using_counter: overrides args.use_counter unless it is None.

    Fixes over the original: the duplicated
    `if use_counter or priotize_using_counter:` blocks are merged, and the
    four copy-pasted return branches collapse into one call by defaulting
    the optional visual encoders to None. Behavior is unchanged.
    """
    # init word embedding module, question embedding module, and Attention network
    w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, .0, args.op)
    q_emb = QuestionEmbedding(300 if 'c' not in args.op else 600, args.num_hid,
                              1, False, .0, args.rnn)
    v_att = BiAttention(dataset.v_dim, args.num_hid, args.num_hid, args.gamma)
    # build and load pre-trained MAML model (None when disabled)
    maml_v_emb = None
    if args.maml:
        weight_path = args.RAD_dir + '/' + args.maml_model_path
        print('load initial weights MAML from: %s' % (weight_path))
        maml_v_emb = SimpleCNN(weight_path, args.eps_cnn, args.momentum_cnn)
    # build and load pre-trained Auto-encoder model (None when disabled)
    ae_v_emb = None
    if args.autoencoder:
        ae_v_emb = Auto_Encoder_Model()
        weight_path = args.RAD_dir + '/' + args.ae_model_path
        print('load initial weights DAE from: %s' % (weight_path))
        ae_v_emb.load_state_dict(torch.load(weight_path))
    # Loading tfidf weighted embedding
    if hasattr(args, 'tfidf'):
        w_emb = tfidf_loading(args.tfidf, w_emb, args)
    # Optional module: counter for BAN
    # NOTE(review): with the default priotize_using_counter=False this is
    # always False and args.use_counter is ignored; args.use_counter is only
    # consulted when a caller explicitly passes priotize_using_counter=None.
    # Kept as-is to preserve caller-visible behavior — confirm intent.
    use_counter = args.use_counter if priotize_using_counter is None else priotize_using_counter
    enable_counter = use_counter or priotize_using_counter
    if enable_counter:
        objects = 10  # minimum number of boxes
        counter = Counter(objects)
    else:
        counter = None
    # init BAN residual network
    b_net = []
    q_prj = []
    c_prj = []
    for _ in range(args.gamma):
        b_net.append(
            BCNet(dataset.v_dim, args.num_hid, args.num_hid, None, k=1))
        q_prj.append(FCNet([args.num_hid, args.num_hid], '', .2))
        if enable_counter:
            c_prj.append(FCNet([objects + 1, args.num_hid], 'ReLU', .0))
    # init classifier
    classifier = SimpleClassifier(args.num_hid, args.num_hid * 2,
                                  dataset.num_ans_candidates, args)
    # construct VQA model and return; disabled encoders are simply None
    return BAN_Model(dataset, w_emb, q_emb, v_att, b_net, q_prj, c_prj,
                     classifier, counter, args, maml_v_emb, ae_v_emb)
# Example #7
def build_ban_foil(dataset, num_hid, num_ans_candidates, op='', gamma=4):
    """Build a BAN variant that uses the FOIL classifier head."""
    w_emb = WordEmbedding(dataset.dictionary.ntoken, 300, .0, op)
    q_emb = QuestionEmbedding(600 if 'c' in op else 300, num_hid, 1, False, .0)
    v_att = BiAttention(dataset.v_dim, num_hid, num_hid, gamma)
    objects = 10  # minimum number of boxes
    # One residual glimpse stack per attention glimpse.
    b_net = [BCNet(dataset.v_dim, num_hid, num_hid, None, k=1)
             for _ in range(gamma)]
    q_prj = [FCNet([num_hid, num_hid], '', .2) for _ in range(gamma)]
    c_prj = [FCNet([objects + 1, num_hid], 'ReLU', .0) for _ in range(gamma)]
    counter = Counter(objects)
    classifier = SimpleClassifierFoil(num_hid, 64, num_ans_candidates)
    return BanModel(dataset, w_emb, q_emb, v_att, b_net, q_prj, c_prj,
                    classifier, counter, op, gamma)
# Example #8
def build_ban(dataset, num_hid, op='', gamma=4, q_emb_type='bert', on_do_q=False, finetune_q=False):
    """Build a BAN whose question encoder is chosen by `q_emb_type`.

    'bert' selects multilingual BERT (768-d); 'rg' / 'pkb' select RNN
    encoders with 100- / 200-d word vectors; 'bertrnn' additionally stacks
    an RNN on top of BERT. Raises NameError for an unrecognized type.
    """
    if 'bert' in q_emb_type:
        q_emb = BertModel.from_pretrained('bert-base-multilingual-cased')
        q_dim = 768
    elif 'rg' in q_emb_type:
        q_emb = RnnQuestionEmbedding(dataset.dictionary.ntoken, 100, num_hid, op)
        q_dim = num_hid
    elif 'pkb' in q_emb_type:
        q_emb = RnnQuestionEmbedding(dataset.dictionary.ntoken, 200, num_hid, op)
        q_dim = num_hid

    if 'bertrnn' in q_emb_type:
        q_emb = BertRnnQuestionEmbedding(q_emb, 200, num_hid, op)
        q_dim = num_hid

    if not finetune_q:  # Freeze question embedding
        frozen_params = (q_emb.parameters() if isinstance(q_emb, BertModel)
                         else q_emb.w_emb.parameters())
        for p in frozen_params:
            p.requires_grad_(False)
    if not on_do_q:  # Remove dropout of question embedding
        for mod in q_emb.modules():
            if isinstance(mod, nn.Dropout):
                mod.p = 0.

    v_att = BiAttention(dataset.v_dim, q_dim, num_hid, gamma)
    objects = 10  # minimum number of boxes
    b_net = [BCNet(dataset.v_dim, q_dim, num_hid, None, k=1)
             for _ in range(gamma)]
    q_prj = [FCNet([num_hid, q_dim], '', .2) for _ in range(gamma)]
    c_prj = [FCNet([objects + 1, q_dim], 'ReLU', .0) for _ in range(gamma)]
    # Two heads: answer classification plus a single-logit auxiliary output.
    classifiers = [SimpleClassifier(q_dim, num_hid * 2, dataset.num_ans_candidates, .5),
                   SimpleClassifier(q_dim, num_hid * 2, 1, .5)]
    counter = Counter(objects)
    return BanModel(dataset, q_emb, v_att, b_net, q_prj, c_prj, classifiers, counter, op, gamma)