class VQA:
    def __init__(self, folder="/", load=True):
        # Datasets
        if load:
            self.train_tuple = get_data_tuple(
                args.train, bs=args.batch_size, shuffle=True, drop_last=True, folder=folder
            )
            if args.valid != "":
                self.valid_tuple = get_data_tuple(
                    args.valid, bs=128,
                    shuffle=False, drop_last=False, folder=folder
                )
            else:
                self.valid_tuple = None
        
        # Model
#         self.model = VQAModel(self.train_tuple.dataset.num_answers)
        self.model = VQAModel(3129)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)
        
        # GPU options
        self.model = self.model.cuda()
        
        
        # Indices of "yes" and "no" in the 3129-way answer vocabulary
        # (values assumed to match the accompanying answer list).
        self.yes_index = 425
        self.no_index = 1403
        
        self.mask_yes = torch.zeros(3129).cuda()  # answer-vocab size; matches VQAModel(3129) above
        self.mask_yes[self.yes_index] = 1.0
        self.mask_yes[self.no_index] = 1.0
        
        
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mseloss = nn.MSELoss()  # used by constraintloss below
        if load:
            if 'bert' in args.optim:
                batch_per_epoch = len(self.train_tuple.loader)
                t_total = int(batch_per_epoch * args.epochs)
                print("BertAdam Total Iters: %d" % t_total)
                from lxrt.optimization import BertAdam
                self.optim = BertAdam(list(self.model.parameters()),
                                      lr=args.lr,
                                      warmup=0.1,
                                      t_total=t_total)
            else:
                self.optim = args.optimizer(self.model.parameters(), args.lr)
            # Output Directory
            self.output = args.output
            os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader), ascii=True)) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, ques, op, q1, q2, typetarget, q1typetarget, q2typetarget, yesnotypetargets, q1yntypetargets, q2yntypetargets, target, q1_target, q2_target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target, yntypetarget, typetarget = feats.cuda(), boxes.cuda(), target.cuda(), yesnotypetargets.cuda(), typetarget.cuda()
                
                op, q1typetarget, q2typetarget, q1yntypetargets, q2yntypetargets, q1_target, q2_target = op.cuda(), q1typetarget.cuda(), q2typetarget.cuda(), q1yntypetargets.cuda(), q2yntypetargets.cuda(), q1_target.cuda(), q2_target.cuda()

                logit = self.model(feats, boxes, ques)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)
                
                q1logit = self.model(feats, boxes, q1)
                q2logit = self.model(feats, boxes, q2)
                
                constraint_loss = self.constraintloss(logit, q1logit, q2logit, op)

                loss = 0.5*loss + 0.5*constraint_loss

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")
        return best_valid
    
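    # rangeloss: an inverted-Gaussian well that penalizes x for leaving
    # [lower, upper]; defined here but unused in this example, which scores
    # constraint violations with MSE in constraintloss instead.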
    def rangeloss(self, x, lower, upper, lamb=4):
        mean = (lower + upper) / 2
        sigma = (upper - lower + 0.00001) / lamb
        loss = 1 - torch.exp(-0.5 * torch.pow(torch.div(x - mean, sigma), 2))
        return loss.sum()
    
    def select_yesnoprobs(self, logit, x, op):
        # Keep only the rows whose logical-operator id equals x, then pull
        # out the "yes"/"no" logits and renormalize them with a 2-way softmax.
        op_mask = torch.eq(op, x)
        logit = logit[op_mask].view(-1, 3129)
        logit_m = logit * self.mask_yes
        m = logit_m == 0  # assumes masked-out entries are exactly zero
        logit_m = logit_m[~m].view(-1, 2)
        logit_m = torch.softmax(logit_m, 1)
        return logit_m.select(dim=1, index=0).view(-1, 1)  # P("yes")
        
    
    def constraintloss(self, logit, q1_logit, q2_logit, op):
        # For each logical-operator id (1..10), compare the parent question's
        # yes-probability against the value implied by combining the two
        # sub-questions with op_map (defined elsewhere in the script).
        total_loss = torch.zeros([1]).cuda()
        for x in range(1, 11):
            logit_m = self.select_yesnoprobs(logit, x, op)
            q1_logit_m = self.select_yesnoprobs(q1_logit, x, op)
            q2_logit_m = self.select_yesnoprobs(q2_logit, x, op)

            if logit_m.nelement() == 0:
                continue

            ideal_logit_m = op_map[x](q1_logit_m, q2_logit_m)
            pair_loss = self.mseloss(logit_m, ideal_logit_m)
            total_loss += pair_loss
        return total_loss
    

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in tqdm(enumerate(loader), ascii=True, desc="Evaluating"):
#             ques_id, feats, boxes, sent = datum_tuple[:4]   # Avoid seeing ground truth
            ques_id, feats, boxes, ques, op, q1, q2, typetarget, q1typetarget, q2typetarget, yesnotypetargets, q1yntypetargets, q2yntypetargets, target, q1_target, q2_target = datum_tuple
    
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, ques)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, ques, op, q1, q2, typetarget, q1typetarget, q2typetarget, yesnotypetargets, q1yntypetargets, q2yntypetargets, target, q1_target, q2_target = datum_tuple
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)
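
    # The training loop above calls self.save, but this snippet omits the
    # helper methods; a minimal sketch, assuming the same checkpointing
    # scheme as the later examples:
    def save(self, name):
        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
        torch.save(model_to_save.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)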

Example #2
class VQA:
    def __init__(self, folder="/", load=True):
        # Datasets
        if load:
            self.train_tuple = get_data_tuple(args.train,
                                              bs=args.batch_size,
                                              shuffle=True,
                                              drop_last=True,
                                              folder=folder)
            if args.valid != "":
                self.valid_tuple = get_data_tuple(args.valid,
                                                  bs=128,
                                                  shuffle=False,
                                                  drop_last=False,
                                                  folder=folder)
            else:
                self.valid_tuple = None

        # Model
#         self.model = VQAModel(self.train_tuple.dataset.num_answers)
        self.model = VQAModel(3129, fn_type=args.fn_type)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Load IndexList of Answer to Type Map
        self.indexlist = json.load(open("data/vqa/indexlist.json"))
        indextensor = torch.cuda.LongTensor(self.indexlist)
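        # mask{0,1,2} are indicator vectors over the answer vocabulary
        # marking which answers belong to each question type.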
        self.mask0 = torch.eq(indextensor, 0).float()
        self.mask1 = torch.eq(indextensor, 1).float()
        self.mask2 = torch.eq(indextensor, 2).float()

        self.mask_cache = {}

        self.yes_index = 425
        self.no_index = 1403

        self.mask_yes = torch.zeros(len(self.indexlist)).cuda()
        self.mask_yes[self.yes_index] = 1.0
        self.mask_yes[self.no_index] = 1.0

        # Loss and Optimizer

        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)

        self.bceloss = nn.BCELoss()
        self.nllloss = nn.NLLLoss()
        self.mseloss = nn.MSELoss()

        self.bce_loss = nn.BCEWithLogitsLoss()
        self.ce_loss = nn.CrossEntropyLoss()

        if load:
            if 'bert' in args.optim:
                batch_per_epoch = len(self.train_tuple.loader)
                t_total = int(batch_per_epoch * args.epochs)
                print("BertAdam Total Iters: %d" % t_total)
                from lxrt.optimization import BertAdam
                self.optim = BertAdam(list(self.model.parameters()),
                                      lr=args.lr,
                                      warmup=0.1,
                                      t_total=t_total)
            else:
                self.optim = args.optimizer(self.model.parameters(), args.lr)
            # Output Directory
            self.output = args.output
            os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader), ascii=True)
                        ) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}

            epoch_loss = 0
            type_loss = 0
            yn_loss = 0
            const_loss = 0
            total_items = 0

            for i, (ques_id, feats, boxes, ques, op, q1, q2, typetarget,
                    q1typetarget, q2typetarget, yesnotypetargets,
                    q1yntypetargets, q2yntypetargets, target, q1_target,
                    q2_target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes = feats.cuda(), boxes.cuda()
                target, yntypetarget, typetarget = target.cuda(), yesnotypetargets.cuda(), typetarget.cuda()

                op, q1_target, q2_target = op.cuda(), q1_target.cuda(), q2_target.cuda()
                q1typetarget, q2typetarget = q1typetarget.cuda(), q2typetarget.cuda()
                q1yntypetargets, q2yntypetargets = q1yntypetargets.cuda(), q2yntypetargets.cuda()

                # The actual question
                logit, type_logit = self.model_forward(feats, boxes, ques,
                                                       target, yntypetarget)

                #                 print(logit,target)

                loss = self.bceloss(logit, target)
                loss = loss * logit.size(1)

                loss_type = self.nllloss(type_logit, typetarget)
                loss_type = loss_type * type_logit.size(1)

                #                 loss_yn = self.bce_loss(yn_type_logit,yntypetarget)
                #                 loss_yn = loss_yn*yn_type_logit.size(1)

                # Q1 and Q2 Prediction, no loss for these predictions only constraint loss for these guys

                q1_logit, q1_type_logit = self.model_forward(
                    feats, boxes, q1, q1_target, q1yntypetargets)

                q2_logit, q2_type_logit = self.model_forward(
                    feats, boxes, q2, q2_target, q2yntypetargets)

                constraint_loss = self.constraintloss(logit, q1_logit,
                                                      q2_logit, op)

                # Final Loss

                #                 print(loss, loss_type,loss_yn, constraint_loss)
                epoch_loss += loss.item()
                type_loss += loss_type.item()
                #                 yn_loss+=loss_yn.item()
                const_loss += constraint_loss.item()
                total_items += 1

                loss = 0.5 * loss + 0.25 * loss_type + 0.25 * constraint_loss

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.) + \
                           "Epoch Losses: QA : %0.5f, YN : %0.5f, Type : %0.5f, Const : %0.5f\n"%(epoch_loss/total_items,yn_loss/total_items,type_loss/total_items, const_loss/total_items)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")
        return best_valid

    def model_forward(self, feats, boxes, sent, target, yntypetarget):
        # Returns (type-mixed answer probabilities, log-softmax type logits).
        # The answer scores are already sigmoid-activated, which is why the
        # training loop uses nn.BCELoss rather than BCEWithLogitsLoss.
        logit, type_logit = self.model(feats, boxes, sent)
        assert logit.dim() == target.dim() == 2

        logit = self.sigmoid(logit)
        type_logit_soft = self.softmax(type_logit)
        type_logit = self.logsoftmax(type_logit)

        logit = self.calculatelogits(logit, type_logit_soft, "train")
        return logit, type_logit

    def get_masks(self, mode, batch):
        # Mask caching is disabled here; see the cached variant in Example #3.
        mask0 = self.mask0.repeat([batch, 1])
        mask1 = self.mask1.repeat([batch, 1])
        mask2 = self.mask2.repeat([batch, 1])
        return [mask0, mask1, mask2]

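    # calculatelogits mixes the per-type answer scores:
    # P(ans) = sum_t mask_t[ans] * P(ans) * P(type = t),
    # suppressing answers that fall outside the predicted question type.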
    def calculatelogits(self, anspreds, typepreds, mode):
        batch = anspreds.size()[0]
        mask0, mask1, mask2 = self.get_masks(mode, batch)
        anspreds0 = anspreds * mask0 * typepreds.select(
            dim=1, index=0).reshape([batch, 1]).repeat([1, 3129])
        anspreds1 = anspreds * mask1 * typepreds.select(
            dim=1, index=1).reshape([batch, 1]).repeat([1, 3129])
        anspreds2 = anspreds * mask2 * typepreds.select(
            dim=1, index=2).reshape([batch, 1]).repeat([1, 3129])
        nanspreds = anspreds0 + anspreds1 + anspreds2
        return nanspreds

    def rangeloss(self, x, lower, upper, lamb=4):
        mean = (lower + upper) / 2
        sigma = (upper - lower + 0.00001) / lamb
        loss = 1 - torch.exp(-0.5 * torch.pow(torch.div(x - mean, sigma), 2))
        return loss.sum()

    def select_yesnoprobs(self, logit, x, op):
        op_mask = torch.eq(op, x)
        logit = logit[op_mask].view(-1, 3129)
        logit_m = logit * self.mask_yes
        m = logit_m == 0
        logit_m = logit_m[~m].view(-1, 2)
        logit_m = torch.softmax(logit_m, 1)
        return logit_m.select(dim=1, index=0).view(-1, 1)

    def constraintloss(self, logit, q1_logit, q2_logit, op):
        total_loss = torch.zeros([1]).cuda()
        for x in range(1, 11):
            logit_m = self.select_yesnoprobs(logit, x, op)
            q1_logit_m = self.select_yesnoprobs(q1_logit, x, op)
            q2_logit_m = self.select_yesnoprobs(q2_logit, x, op)

            if logit_m.nelement() == 0:
                continue

            ideal_logit_m = op_map[x](q1_logit_m, q2_logit_m)
            rangeloss = self.mseloss(logit_m, ideal_logit_m)
            total_loss += rangeloss
        return total_loss

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}

        type_accuracy = 0.0
        yn_type_accuracy = 0.0
        num_batches = 0
        for i, datum_tuple in tqdm(enumerate(loader),
                                   ascii=True,
                                   desc="Evaluating"):
            #             ques_id, feats, boxes, sent, typed_target, yesnotypetargets, target = datum_tuple   # Avoid seeing ground truth
            ques_id, feats, boxes, ques, op, q1, q2, typetarget, q1typetarget, q2typetarget, yesnotypetargets, q1yntypetargets, q2yntypetargets, target, q1_target, q2_target = datum_tuple
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit, typelogit = self.model(feats, boxes, ques)

                logit = self.sigmoid(logit)
                type_logit_soft = self.softmax(typelogit)
                logit = self.calculatelogits(logit, type_logit_soft, "predict")

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

                type_accuracy += torch.mean(
                    torch.eq(
                        typelogit.argmax(dim=1).cuda(),
                        typetarget.cuda()).float().cuda()).cpu().item()
                #                 yn_type_accuracy+= torch.mean(torch.eq((yn_type_logit>0.5).float().cuda(),yesnotypetargets.cuda().float()).float().cuda()).cpu().item()
                num_batches += 1

        print("Type Accuracy:", type_accuracy / num_batches)
        print("YN Accuracy:", yn_type_accuracy / num_batches)

        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, ques, op, q1, q2, typetarget, q1typetarget, q2typetarget, yesnotypetargets, q1yntypetargets, q2yntypetargets, target, q1_target, q2_target = datum_tuple
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        model_to_save = self.model.module if hasattr(self.model,
                                                     "module") else self.model
        torch.save(model_to_save.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)

        self.model.load_state_dict(state_dict)
Example #3
class VQA:
    def __init__(self, folder="/", load=True):
        # Datasets
        if load:
            self.train_tuple = get_data_tuple(args.train,
                                              bs=args.batch_size,
                                              shuffle=True,
                                              drop_last=True,
                                              folder=folder)
            if args.valid != "":
                self.valid_tuple = get_data_tuple(args.valid,
                                                  bs=512,
                                                  shuffle=False,
                                                  drop_last=False,
                                                  folder=folder)
            else:
                self.valid_tuple = None

        # Model
#         self.model = VQAModel(self.train_tuple.dataset.num_answers)
        self.model = VQAModel(len(self.train_tuple.dataset.label2ans),
                              fn_type=args.fn_type)  # note: requires load=True, since it reads train_tuple

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Load IndexList of Answer to Type Map
        self.indexlist = json.load(
            open(
                "/data/datasets/vqa_mutant/data/vqa/mutant_l2a/mutant_merge_indexlist.json"
            ))

        print("Length of Masks", len(self.indexlist), flush=True)

        indextensor = torch.cuda.LongTensor(self.indexlist)
        self.mask0 = torch.eq(indextensor, 0).float()
        self.mask1 = torch.eq(indextensor, 1).float()
        self.mask2 = torch.eq(indextensor, 2).float()
        self.mask3 = torch.eq(indextensor, 3).float()

        self.mask_cache = {}

        # Loss and Optimizer

        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)

        self.bceloss = nn.BCELoss()
        self.nllloss = nn.NLLLoss()

        self.bce_loss = nn.BCEWithLogitsLoss()
        self.ce_loss = nn.CrossEntropyLoss()

        if load:
            if 'bert' in args.optim:
                batch_per_epoch = len(self.train_tuple.loader)
                t_total = int(batch_per_epoch * args.epochs)
                print("BertAdam Total Iters: %d" % t_total)
                from lxrt.optimization import BertAdam
                self.optim = BertAdam(list(self.model.parameters()),
                                      lr=args.lr,
                                      warmup=0.1,
                                      t_total=t_total)
            else:
                self.optim = args.optimizer(self.model.parameters(), args.lr)
            # Output Directory
            self.output = args.output
            os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader), ascii=True)
                        ) if args.tqdm else (lambda x: x)

        total_steps = len(loader)
        eval_every = int(0.2 * total_steps)  # validate roughly 5x per epoch

        best_valid = 0.
        best_i = 0
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, typetarget,
                    target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target, typetarget = feats.cuda(), boxes.cuda(
                ), target.cuda(), typetarget.cuda()
                logit, type_logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2

                logit = self.sigmoid(logit)
                type_logit_soft = self.softmax(type_logit)
                type_logit = self.logsoftmax(type_logit)

                logit = self.calculatelogits(logit, type_logit_soft, "train")

                loss = self.bceloss(logit, target)
                loss = loss * logit.size(1)

                loss_type = self.nllloss(type_logit, typetarget)
                loss_type = loss_type * type_logit.size(1)

                loss = 0.9 * loss + 0.1 * loss_type

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

                if ((i + 1) % eval_every
                        == 0) and self.valid_tuple is not None:
                    log_str = "\nEpoch %d, Step %d: Train %0.2f\n" % (
                        epoch, i, evaluator.evaluate(quesid2ans) * 100.)

                    if self.valid_tuple is not None:  # Do Validation
                        valid_score = self.evaluate(eval_tuple)
                        if valid_score > best_valid:
                            best_valid = valid_score
                            best_i = i
                            self.save("BEST")

                        log_str += "Epoch %d, Step %d: Valid %0.2f\n" % (epoch,i, valid_score * 100.) + \
                                "Epoch %d, Best Step %d: Best %0.2f\n" % (epoch,best_i, best_valid * 100.)

                    print(log_str, end='')

                    with open(self.output + "/log.log", 'a') as f:
                        f.write(log_str)
                        f.flush()

        self.save("LAST")
        return best_valid

    def get_masks(self, mode, batch):
        key = mode + str(batch)
        if key in self.mask_cache:
            return self.mask_cache[key]

        mask0 = self.mask0.repeat([batch, 1])
        mask1 = self.mask1.repeat([batch, 1])
        mask2 = self.mask2.repeat([batch, 1])
        mask3 = self.mask3.repeat([batch, 1])

        self.mask_cache[key] = [mask0, mask1, mask2, mask3]
        return self.mask_cache[key]

    def calculatelogits(self, anspreds, typepreds, mode):
        batch = anspreds.size()[0]
        mask0, mask1, mask2, mask3 = self.get_masks(mode, batch)
        replen = len(self.train_tuple.dataset.label2ans)
        anspreds0 = anspreds * mask0 * typepreds.select(
            dim=1, index=0).reshape([batch, 1]).repeat([1, replen])
        anspreds1 = anspreds * mask1 * typepreds.select(
            dim=1, index=1).reshape([batch, 1]).repeat([1, replen])
        anspreds2 = anspreds * mask2 * typepreds.select(
            dim=1, index=2).reshape([batch, 1]).repeat([1, replen])
        anspreds3 = anspreds * mask3 * typepreds.select(
            dim=1, index=3).reshape([batch, 1]).repeat([1, replen])
        nanspreds = anspreds0 + anspreds1 + anspreds2 + anspreds3
        return nanspreds

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}

        type_accuracy = 0.0
        num_batches = 0
        for i, datum_tuple in tqdm(enumerate(loader),
                                   ascii=True,
                                   desc="Evaluating"):
            ques_id, feats, boxes, sent, typed_target, target = datum_tuple  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit, typelogit = self.model(feats, boxes, sent)

                logit = self.sigmoid(logit)
                type_logit_soft = self.softmax(typelogit)
                logit = self.calculatelogits(logit, type_logit_soft, "predict")

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans


                # type_accuracy += torch.mean(torch.eq((typelogit > 0.5).int().cuda(), typed_target.cuda()).float().cuda()).cpu().item()
                type_accuracy += torch.mean(
                    torch.eq(
                        typelogit.argmax(dim=1).cuda(),
                        typed_target.cuda()).float().cuda()).cpu().item()
                num_batches += 1

        print("Type Accuracy:", type_accuracy / num_batches)

        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent, typetarget, target = datum_tuple
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        model_to_save = self.model.module if hasattr(self.model,
                                                     "module") else self.model
        torch.save(model_to_save.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
Example #4
class VQA:
    def __init__(self, folder="/", load=True):
        # Datasets
        if load:
            self.train_tuple = get_data_tuple(args.train,
                                              bs=args.batch_size,
                                              shuffle=True,
                                              drop_last=True,
                                              folder=folder)
            if args.valid != "":
                self.valid_tuple = get_data_tuple(args.valid,
                                                  bs=128,
                                                  shuffle=False,
                                                  drop_last=False,
                                                  folder=folder,
                                                  nops=args.nops)
            else:
                self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)
        #         is_cp=False
        #         if "vqacpv2" in folder:
        #             is_cp=True
        #         if not is_cp:
        #             self.model = VQAModel(3129)
        #         else:
        #             self.model = VQAModel(16039)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

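        # Answer-embedding table: one vector per answer in the vocabulary,
        # L2-normalized; the 1e-8 offset guards against zero vectors.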
        ans_embed = np.load(
            "/data/datasets/vqa_mutant/data/vqa/mutant_l2a/answer_embs.npy"
        ) + 1e-8
        ans_embed = torch.tensor(ans_embed).cuda()
        self.ans_embed = torch.nn.functional.normalize(ans_embed, dim=1)
        self.embed_cache = {}

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if load:
            if 'bert' in args.optim:
                batch_per_epoch = len(self.train_tuple.loader)
                t_total = int(batch_per_epoch * args.epochs)
                print("BertAdam Total Iters: %d" % t_total)
                from lxrt.optimization import BertAdam
                self.optim = BertAdam(list(self.model.parameters()),
                                      lr=args.lr,
                                      warmup=0.1,
                                      t_total=t_total)
            else:
                self.optim = args.optimizer(self.model.parameters(), args.lr)
            # Output Directory
            self.output = args.output
            os.makedirs(self.output, exist_ok=True)

        self.cos = nn.CosineSimilarity()

    def get_answer_embs(self, bz):
        if bz in self.embed_cache:
            return self.embed_cache[bz]
        emb = torch.stack([self.ans_embed] * bz)
        self.embed_cache[bz] = emb
        return emb

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader), ascii=True)
                        ) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target,
                    gold_embs) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target, gold_emb = feats.cuda(), boxes.cuda(), target.cuda(), gold_embs.cuda()
                gold_emb = torch.nn.functional.normalize(gold_emb, dim=1)
                gen_embs, logits = self.model(feats, boxes, sent)
                #                 all_ans_embs = self.get_answer_embs(gen_embs.shape[0])
                all_ans_embs = self.model.emb_proj(self.ans_embed)
                all_ans_embs = torch.stack([all_ans_embs] * gen_embs.shape[0])
                gold_emb = self.model.emb_proj(gold_emb)

                cos = nn.CosineSimilarity(dim=1)
                positive_dist = cos(gen_embs, gold_emb)  # shape b,k;b,k-> b
                gen_embs = torch.cat([gen_embs.unsqueeze(1)] *
                                     all_ans_embs.shape[1],
                                     dim=1)
                cos = nn.CosineSimilarity(dim=2)
                d_logit = cos(gen_embs, all_ans_embs)

                #                 print(logit,positive_dist,flush=True)

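                # InfoNCE-style contrastive objective: maximize similarity to
                # the gold answer embedding relative to every answer embedding.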
                num = torch.exp(positive_dist).squeeze(-1)
                # print(num,num.shape,flush=True)
                den = torch.exp(d_logit).sum(-1)
                # print(den,den.shape,flush=True)
                loss = -1 * torch.log(num / den)
                loss = loss.mean() * d_logit.size(1)

                assert logits.dim() == target.dim() == 2
                acloss = self.bce_loss(logits, target)
                acloss = acloss * logits.size(1)

                loss = acloss + loss

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logits.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")
        return best_valid

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in tqdm(enumerate(loader),
                                   ascii=True,
                                   desc="Evaluating"):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                embs, logits = self.model(feats, boxes, sent)
                #                 all_ans_embs = self.model.emb_proj(self.ans_embed)
                #                 all_ans_embs  = torch.stack([all_ans_embs]*embs.shape[0])
                #                 logit = torch.einsum("bj,bkj->bk",embs,all_ans_embs)
                score, label = logits.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target, emb) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        model_to_save = self.model.module if hasattr(self.model,
                                                     "module") else self.model
        torch.save(model_to_save.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
Example #5
class NLVR2:
    def __init__(self):
        self.train_tuple = get_tuple(
            args.train, bs=args.batch_size, shuffle=True, drop_last=True
        )
        if args.valid != "":
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(
                args.valid, bs=valid_bsize,
                shuffle=False, drop_last=False
            )
        else:
            self.valid_tuple = None

        self.momentum = 0.9995
        self.model = NLVR2Model()
        self.siam_model = copy.deepcopy(self.model)
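        # siam_model is an exponential-moving-average copy of the main model
        # (mean-teacher style); it is updated after every optimizer step via
        # update_ema_variables below.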

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
            self.siam_model.lxrt_encoder.load(args.load_lxmert)

        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)
            load_lxmert_qa(args.load_lxmert_qa, self.siam_model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()
            self.siam_model.lxrt_encoder.multi_gpu()
        self.model = self.model.cuda()
        self.siam_model = self.siam_model.cuda()

        # Losses and optimizer
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def update_ema_variables(self):
        # Exponential moving average (mean-teacher style) of the main
        # model's weights into the siamese model.
        alpha = self.momentum
        for ema_param, param in zip(self.siam_model.parameters(),
                                    self.model.parameters()):
            ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)


    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid1 = 0.
        best_valid2 = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, label) in iter_wrapper(enumerate(loader)):
                self.model.train()

                self.optim.zero_grad()
                feats, boxes, label = feats.cuda(), boxes.cuda(), label.cuda()
                logit = self.model(feats, boxes, sent)

                loss = self.mce_loss(logit, label)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()
                self.update_ema_variables()

                score, predict = logit.max(1)
                for qid, l in zip(ques_id, predict.cpu().numpy()):
                    quesid2ans[qid] = l

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score1, valid_score2 = self.evaluate(eval_tuple)
                if valid_score1 > best_valid1:
                    best_valid1 = valid_score1
                    self.save1("BEST")

                if valid_score2 > best_valid2:
                    best_valid2 = valid_score2
                    self.save2("BEST_siam")

                log_str += "Epoch %d: Valid1 %0.2f\n" % (epoch, valid_score1 * 100.) + \
                           "Epoch %d: Best1 %0.2f\n" % (epoch, best_valid1 * 100.)

                log_str += "Epoch %d: Valid2 %0.2f\n" % (epoch, valid_score2 * 100.) + \
                           "Epoch %d: Best2 %0.2f\n" % (epoch, best_valid2 * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save1("LAST1")
        self.save2("LAST2")

    def predict(self, eval_tuple: DataTuple, dump=None):
        self.model.eval()
        self.siam_model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans1 = {}
        quesid2ans2 = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]   # avoid handling target
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score1, predict1 = logit.max(1)
                for qid, l in zip(ques_id, predict1.cpu().numpy()):
                    quesid2ans1[qid] = l

        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]   # avoid handling target
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit2 = self.siam_model(feats, boxes, sent)
                score2, predict2 = logit2.max(1)
                for qid, l in zip(ques_id, predict2.cpu().numpy()):
                    quesid2ans2[qid] = l

        if dump is not None:
            # Note: both dumps write to the same path, so the siamese model's
            # predictions overwrite the base model's.
            evaluator.dump_result(quesid2ans1, dump)
            evaluator.dump_result(quesid2ans2, dump)
        return quesid2ans1, quesid2ans2

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        dset, loader, evaluator = eval_tuple
        quesid2ans1, quesid2ans2 = self.predict(eval_tuple, dump)
        return evaluator.evaluate(quesid2ans1), evaluator.evaluate(quesid2ans2)

    def save1(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def save2(self, name):
        torch.save(self.siam_model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
Example #6
class VQA:
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=1024,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent,
                    target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
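
# A minimal driver sketch, assuming the standard LXMERT-style entry point
# (args and the helpers above come from the surrounding script):
if __name__ == "__main__":
    vqa = VQA()
    if args.load is not None:  # optionally resume from a fine-tuned checkpoint
        vqa.load(args.load)
    if vqa.valid_tuple is not None:
        print("Valid Oracle: %0.2f" % (vqa.oracle_score(vqa.valid_tuple) * 100))
    vqa.train(vqa.train_tuple, vqa.valid_tuple)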
Example #7
class VQA:
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=1024,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self,
              train_tuple,
              eval_tuple,
              adversarial=False,
              adv_batch_prob=0.0,
              attack_name=None,
              attack_params={}):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)
        use_adv_batch = False

        best_valid = 0.

        for epoch in range(args.epochs):
            quesid2ans = {}
            # Count the number of batches that were adversarially perturbed
            n_adv_batches = 0
            for i, (ques_id, feats, boxes, sent,
                    target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()

                # If doing adversarial training, perturb input features
                # with probability adv_batch_prob
                if adversarial:
                    rand = random.uniform(0, 1)
                    use_adv_batch = rand <= adv_batch_prob
                if use_adv_batch:
                    # Create adversary from given class name and parameters
                    n_adv_batches += 1
                    AdversaryClass_ = getattr(advertorch_module, attack_name)
                    adversary = AdversaryClass_(
                        lambda x: self.model(x, boxes, sent),
                        loss_fn=self.bce_loss,
                        **attack_params)
                    # Perturb feats using adversary
                    feats = adversary.perturb(feats, target)

                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.) + \
                        "Epoch %d: Num adversarial batches %d / %d\n" % (epoch, n_adv_batches, i+1)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def adversarial_predict(self,
                            eval_tuple: DataTuple,
                            dump=None,
                            attack_name='GradientAttack',
                            attack_params={}):
        """
        Predict the answers to questions in a data split, but
        using a specified adversarial attack on the inputs.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :param attack_name: Name of the advertorch attack class to instantiate.
        :param attack_params: Keyword arguments passed to the attack constructor.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        sim_trace = []  # Track avg cos similarity across batches
        for i, datum_tuple in enumerate(tqdm(loader)):
            ques_id, feats, boxes, sent, target = datum_tuple
            feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()

            # Create adversary from given class name and parameters
            AdversaryClass_ = getattr(advertorch_module, attack_name)
            adversary = AdversaryClass_(lambda x: self.model(x, boxes, sent),
                                        loss_fn=self.bce_loss,
                                        **attack_params)

            # Perturb feats using adversary
            feats_adv = adversary.perturb(feats, target)

            # Compute average cosine similarity between true
            # and perturbed features
            sim_trace.append(self.avg_cosine_sim(feats, feats_adv))

            # Compute prediction on adversarial examples
            with torch.no_grad():
                feats_adv = feats_adv.cuda()
                logit = self.model(feats_adv, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        print(
            f"Average cosine similarity across batches: {torch.mean(torch.Tensor(sim_trace))}"
        )
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    def adversarial_evaluate(self,
                             eval_tuple: DataTuple,
                             dump=None,
                             attack_name='GradientAttack',
                             attack_params={}):
        """Evaluate model on adversarial inputs"""
        quesid2ans = self.adversarial_predict(eval_tuple, dump, attack_name,
                                              attack_params)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    def avg_cosine_sim(self, feats: torch.Tensor, feats_adv: torch.Tensor):
        """Computes the average cosine similarity between true and adversarial examples"""
        return nn.functional.cosine_similarity(feats, feats_adv, dim=-1).mean()

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
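
A minimal driver sketch for the adversarial variant above, assuming advertorch is installed and that the chosen attack accepts the constructor keywords given; the attack name and eps value below are illustrative, not taken from the original script:

# Hypothetical usage of the adversarial trainer/evaluator defined above.
vqa = VQA()
clean_score = vqa.evaluate(vqa.valid_tuple)
adv_score = vqa.adversarial_evaluate(
    vqa.valid_tuple,
    attack_name='GradientSignAttack',  # resolved via getattr on the advertorch module
    attack_params={'eps': 0.1})
print("clean %.4f -> adversarial %.4f" % (clean_score, adv_score))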
Example #8
class VQA:
    def __init__(self):
        # Model
        self.model = VQAModel()

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)

        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()

        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)

        self.valid_tuple = get_data_tuple(args.valid,
                                          bs=1024,
                                          shuffle=False,
                                          drop_last=False)

        self.test_tuple = get_data_tuple(args.test,
                                         bs=1024,
                                         shuffle=False,
                                         drop_last=False)

        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple, test_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        start_epoch = -1
        best_valid = 0.

        if args.RESUME:
            path_checkpoint = args.path_checkpoint  # e.g. ../../model/checkpoint/lxr_best_%s.pth (checkpoint path)
            checkpoint = torch.load(path_checkpoint)  # load the checkpoint

            self.model.load_state_dict(checkpoint['model_state_dict'])

            self.optim.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']

        for epoch in range(start_epoch + 1, args.epochs):
            for i, (feats, boxes, sent, _,
                    _) in iter_wrapper(enumerate(loader)):

                # construct negative examples
                bs = len(sent)
                index_list = list(range(bs))

                sent_negative = []
                for j in range(bs):
                    choice = random.choice(list(set(index_list) - {j}))
                    sent_negative.append(sent[choice])

                sent = sent + sent_negative

                feats = torch.cat([feats, feats])
                boxes = torch.cat([boxes, boxes])

                target = torch.ones(bs, 1)
                target_negative = torch.zeros(bs, 1)
                target = torch.cat([target, target_negative])
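                # The batch is now doubled: the first bs pairs keep their
                # matching question (target 1); the last bs pair each image
                # with a question sampled from another item (target 0),
                # giving in-batch negatives for the matching head.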

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                batch_score = accuracy(logit, target)

                if i % 500 == 0:
                    print('epoch {}, Step {}/{}, Loss: {}'.format(
                        epoch, i, len(loader), loss.item()))

            log_str = "\nEpoch %d: Loss: %0.4f Train %0.2f\n" % (
                epoch, loss.item(), batch_score)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple, epoch)

                self.save_checkpoint(epoch)
                self.save(epoch)
                self.test_output(test_tuple, epoch)
                print('output done!')
                if valid_score > best_valid:
                    best_valid = valid_score

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score ) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid )

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, epoch=0):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        preds = []
        query_ids = []
        product_ids = []
        with torch.no_grad():
            for i, datum_tuple in enumerate(loader):
                feats, boxes, sent, ques_id, product_id = datum_tuple
                query_ids.extend(ques_id)
                product_ids.extend(product_id)
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                logit = torch.sigmoid(logit)
                preds.append(logit.cpu().numpy())
        preds = np.concatenate(preds)

        # deal with format
        query2product = collections.defaultdict(list)
        for i, query_id in enumerate(query_ids):
            pred = preds[i]
            product_id = product_ids[i]
            query2product[str(query_id.item())].append(
                [pred.tolist()[0], str(product_id.item())])

        with open('../../user_data/lxmert_model/result/val/val_%s.txt' % epoch,
                  'w') as f:
            for q, p in query2product.items():
                q = str(q)
                p = str(p)
                f.write(q + ',' + p + '\n')

        with open('submission.csv', 'w') as f:
            f.write('query-id,product1,product2,product3,product4,product5\n')
            for q, p in query2product.items():
                p = sorted(p, key=lambda x: x[0], reverse=True)
                o = ','.join([p[i][1] for i in range(5)])
                f.write(q + ',' + o + '\n')

        os.system('python eval.py submission.csv')
        score = json.load(open('score.json'))["score"]
        return score

    def test_output(self, eval_tuple: DataTuple, epoch=0):
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        preds = []
        query_ids = []
        product_ids = []
        with torch.no_grad():
            for i, datum_tuple in enumerate(loader):
                feats, boxes, sent, ques_id, product_id = datum_tuple
                query_ids.extend(ques_id)
                product_ids.extend(product_id)
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                logit = torch.sigmoid(logit)
                preds.append(logit.cpu().numpy())
        preds = np.concatenate(preds)

        # deal with format
        query2product = collections.defaultdict(list)
        for i, query_id in enumerate(query_ids):
            pred = preds[i]
            product_id = product_ids[i]
            query2product[str(query_id.item())].append(
                [pred.tolist()[0], str(product_id.item())])

        print(os.getcwd())

        with open(
                '../../user_data/lxmert_model/result/test/test_%s.txt' % epoch,
                'w') as f:
            for q, p in query2product.items():
                q = str(q)
                p = str(p)
                f.write(q + ',' + p + '\n')

    def evaluate(self, eval_tuple: DataTuple, epoch=0):
        """Evaluate all data in data_tuple."""
        score = self.predict(eval_tuple, epoch)
        return score  #eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, epoch):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % (str(epoch))))

    def save_checkpoint(self, epoch):
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
        }

        if not os.path.isdir('../../user_data/lxmert_model/checkpoint'):
            os.mkdir('../../user_data/lxmert_model/checkpoint')
        torch.save(checkpoint,
                   '../../user_data/lxmert_model/checkpoint/lxr_best.pth')

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
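
A sketch of how this retrieval-style trainer would be driven; the data tuples are built in __init__ from the same global args, so (assuming those are configured) the entry point reduces to:

# Hypothetical driver for the query/product matcher above.
vqa = VQA()
vqa.train(vqa.train_tuple, vqa.valid_tuple, vqa.test_tuple)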
Example #9
class CAP:
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=1024,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = CAPModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        # self.bce_loss = nn.BCEWithLogitsLoss()
        self.CrossEntropyLoss = nn.CrossEntropyLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent,
                    target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit, out, input_ids, _ = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                # loss = self.bce_loss(logit, target)
                # loss = loss * logit.size(1)
                loss = self.CrossEntropyLoss(out.view(-1, 30000),
                                             input_ids.view(-1))

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, loss.item())

            # if self.valid_tuple is not None:  # Do Validation
            #     valid_score = self.evaluate(eval_tuple)
            #     if valid_score > best_valid:
            #         best_valid = valid_score
            #         self.save("BEST")
            #     log_str += "Epoch %d: Valid %0.2f\n" % (epoch, loss.item())

            print(log_str, end='')
            #print(sent)
            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Generate captions for the images in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to generated caption.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        captions = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit, _, _, output_sents = self.model(feats, boxes, sent)
                #print(len(output_sents))
                #print(feats.shape)
                # dist = Categorical(logits=F.log_softmax(logit[0], dim=-1))
                # pred_idxs = dist.sample().tolist()
                # label = BertTokenizer.convert_ids_to_tokens(pred_idxs)
                score, label = logit.max(1)
                #dset.convert_ids_to_tokens
                #for qid, l in zip(ques_id, label.cpu().numpy()):
                # ans = dset.label2ans[l]
                #ans = l
                #quesid2ans[qid.item()] = ans
                for qid, caption in zip(ques_id, output_sents):
                    words = tokenizer.convert_ids_to_tokens(caption)
                    captions[qid.item()] = [" ".join(words)]
                    #print(type(qid))
                    #print(type(caption))
        if dump is not None:
            evaluator.dump_result(captions, dump)
        return captions

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
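
CAP.predict turns generated token ids back into text with tokenizer.convert_ids_to_tokens. A minimal sketch of that decoding step, assuming a HuggingFace-style BertTokenizer (the ids below are made up for illustration):

# Sketch of the id -> caption decoding used in CAP.predict.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
caption_ids = [1037, 4937, 2003, 3147]  # hypothetical ids
words = tokenizer.convert_ids_to_tokens(caption_ids)
print(" ".join(words))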
Example #10
class GQA:
    def __init__(self):
        self.train_tuple = get_tuple(
            args.train, bs=args.batch_size, shuffle=True, drop_last=True
        )
        if args.valid != "":
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(
                args.valid, bs=valid_bsize, shuffle=False, drop_last=False
            )
        else:
            self.valid_tuple = None

        self.model = GQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(
                args.load_lxmert_qa,
                self.model,
                label2ans=self.train_tuple.dataset.label2ans,
            )

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Losses and optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if "bert" in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam

            self.optim = BertAdam(
                list(self.model.parameters()), lr=args.lr, warmup=0.1, t_total=t_total
            )
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (
            (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
        )

        best_valid = 0.0
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(
                enumerate(loader)
            ):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                if args.mce_loss:
                    max_value, target = target.max(1)
                    loss = self.mce_loss(logit, target) * logit.size(1)
                else:
                    loss = self.bce_loss(logit, target)
                    loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch,
                evaluator.evaluate(quesid2ans) * 100.0,
            )

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (
                    epoch,
                    valid_score * 100.0,
                ) + "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.0)

            print(log_str, end="")

            with open(self.output + "/log.log", "a") as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # avoid handling target
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        dset, loader, evaluator = eval_tuple
        quesid2ans = self.predict(eval_tuple, dump)
        return evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(), os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
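        # Checkpoints saved from a DataParallel-wrapped model prefix keys
        # with ".module"; strip it so keys match the single-GPU module.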
        for key in list(state_dict.keys()):
            if ".module" in key:
                state_dict[key.replace(".module", "")] = state_dict.pop(key)
        self.model.load_state_dict(state_dict, strict=False)
Example #11
class Classifier:
    def __init__(self):

        if args.train_json != "-1":
            self.train_tuple = get_tuple(
                args.train_json, bs=args.batch_size, shuffle=True, drop_last=True
            )
        else:
            self.train_tuple = None
        
        if args.valid_json != "-1":
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(
                args.valid_json, bs=valid_bsize,
                shuffle=False, drop_last=False
            )
        else:
            self.valid_tuple = None

        n_answers = len(json.load(open(args.ans2label)))

        self.model = ClassifierModel(n_answers, model_type=args.model_type)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Losses and optimizer, only if training
        if args.train_json != "-1":
            self.bce_loss = nn.BCEWithLogitsLoss()
            self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
            if 'bert' in args.optim:
                batch_per_epoch = len(self.train_tuple.loader)
                t_total = int(batch_per_epoch * args.epochs)
                print("Total Iters: %d" % t_total)
                from lxrt.optimization import BertAdam
                self.optim = BertAdam(list(self.model.parameters()),
                                      lr=args.lr,
                                      warmup=0.1,
                                      t_total=t_total)
            else:
                self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output_dir
        self.best_name = None
        
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            instance_id2pred = {}
            for i, (instance_ids, feats, boxes, sent, logit_in, target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                # for gradient checking
                # feats.requires_grad = True
                
                feats, boxes, logit_in, target = feats.cuda(), boxes.cuda(), logit_in.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent) + logit_in
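                # logit_in is a per-class score the loader supplies alongside
                # each example (e.g. from an upstream model); adding it makes
                # this head a residual correction on top of those scores.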

                # for gradient checks --- this errors for concat model (no dependence)
                # but gives a gradient for the full model
                # # text feature gradient
                # dldf = torch.autograd.grad(
                #     torch.sum(logit),
                #     self.model.lxrt_encoder.model.bert.tmp_cur_embedding_output,
                #     create_graph=True)[0]
                # # image feature double gradient
                # print(torch.autograd.grad(torch.sum(dldf), feats)[0])

                if target.dim() == 1: #expand targets in binary mode
                    assert logit.size(1) == 1
                    target = target.unsqueeze(1)
                assert logit.dim() == target.dim() == 2

                if logit.size(1) > 1: # multiclass, mce loss
                    max_value, target = target.max(1)
                    loss = self.mce_loss(logit, target) * logit.size(1)
                else: # binary
                    loss = self.bce_loss(logit, target)
                    loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                if logit.size()[1] > 1:
                    score, label = logit.max(1)
                else:
                    score = logit.flatten()
                    label = (logit > 0).float().flatten()
                
                for instance_id, l, scores in zip(instance_ids, label.cpu().numpy(), logit.detach().cpu().numpy()):
                    ans = dset.label2ans[l]
                    instance_id2pred[instance_id] = {'answer': ans, 'label': l, 'scores': scores}

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch,
                                                     evaluator.evaluate(instance_id2pred) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.best_name = 'epoch_{}_valscore_{:.5f}_argshash_{}'.format(epoch, valid_score * 100., args.args_hash)
                    self.save(self.best_name)

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        instance_id2pred = {}
        for i, datum_tuple in tqdm(enumerate(loader), total=len(loader)):
            instance_ids, feats, boxes, sent, logit_in = datum_tuple[:5]   # avoid handling target
            with torch.no_grad():
                feats, boxes, logit_in = feats.cuda(), boxes.cuda(), logit_in.cuda()
                
                logit = self.model(feats, boxes, sent) + logit_in
                if logit.size()[1] > 1:
                    score, label = logit.max(1)
                else:
                    score = logit.flatten()
                    label = (logit > 0).float().flatten()
                for instance_id, l, scores in zip(instance_ids, label.cpu().numpy(), logit.detach().cpu().numpy()):
                    ans = dset.label2ans[l]
                    instance_id2pred[instance_id] = {'answer': ans, 'label': l, 'scores': scores}
        if dump is not None:
            evaluator.dump_result(instance_id2pred, dump)
        return instance_id2pred

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        dset, loader, evaluator = eval_tuple
        instance_id2pred = self.predict(eval_tuple, dump)
        return evaluator.evaluate(instance_id2pred)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        instance_id2pred = {}
        for i, (instance_ids, feats, boxes, sent, logit_in, target) in enumerate(loader):
            if len(target.size()) > 1 and target.size()[1] > 1:
                _, label = target.max(1)
            else:
                label = torch.flatten(target)
            for instance_id, l in zip(instance_ids, label.cpu().numpy()):
                ans = dset.label2ans[l]
                instance_id2pred[instance_id] = {'answer': ans, 'label': l}
        return evaluator.evaluate(instance_id2pred)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        for key in list(state_dict.keys()):
            if '.module' in key:
                state_dict[key.replace('.module', '')] = state_dict.pop(key)
        self.model.load_state_dict(state_dict, strict=False)
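
In the binary branch above, (logit > 0) is the standard decision rule for a model trained with BCEWithLogitsLoss: a raw logit of 0 maps to a sigmoid probability of 0.5. A quick self-contained check:

# Sanity check of the binary decision rule used in Classifier.
import torch

logit = torch.tensor([[-1.2], [0.3]])
prob = torch.sigmoid(logit)                 # ~0.23 and ~0.57
label = (logit > 0).float().flatten()       # tensor([0., 1.])
assert torch.equal(label, (prob > 0.5).float().flatten())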
Example #12
class VQA:
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=1024,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent,
                    target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple

        question_id2img_id = {x["question_id"]: x["img_id"] for x in dset.data}
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                  do_lower_case=True)
        plt.rcParams['figure.figsize'] = (12, 10)
        num_regions = 36

        count = 0

        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)

                for layer in [0, 4]:
                    for head in [0, 1]:
                        for datapoint in range(len(sent)):
                            print(count, len(sent))
                            count += 1
                            x_layer = self.model.lxrt_encoder.model.bert.encoder.x_layers[layer]
                            lang2vis_attention_probs = x_layer.lang_att_map[
                                datapoint][head].detach().cpu().numpy()
                            vis2lang_attention_probs = x_layer.visn_att_map[
                                datapoint][head].detach().cpu().numpy()

                            plt.clf()

                            plt.subplot(2, 3, 1)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 0-7)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            plt.subplot(2, 3, 2)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 8-15)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            plt.subplot(2, 3, 3)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 16-35)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            img_info = loader.dataset.imgid2img[
                                question_id2img_id[ques_id[datapoint].item()]]
                            img_h, img_w = img_info['img_h'], img_info['img_w']
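                            # Boxes are stored normalized to [0, 1]; scale x
                            # coords by image width and y coords by height
                            # before drawing on the original image.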
                            unnormalized_boxes = boxes[datapoint].clone()
                            unnormalized_boxes[:, (0, 2)] *= img_w
                            unnormalized_boxes[:, (1, 3)] *= img_h

                            for i, bbox in enumerate(unnormalized_boxes):
                                if i < 8:
                                    plt.subplot(2, 3, 1)
                                elif i < 16:
                                    plt.subplot(2, 3, 2)
                                else:
                                    plt.subplot(2, 3, 3)

                                bbox = [
                                    bbox[0].item(), bbox[1].item(),
                                    bbox[2].item(), bbox[3].item()
                                ]

                                if bbox[0] == 0:
                                    bbox[0] = 2
                                if bbox[1] == 0:
                                    bbox[1] = 2

                                plt.gca().add_patch(
                                    plt.Rectangle((bbox[0], bbox[1]),
                                                  bbox[2] - bbox[0] - 4,
                                                  bbox[3] - bbox[1] - 4,
                                                  fill=False,
                                                  edgecolor='red',
                                                  linewidth=1))

                                plt.gca().text(bbox[0],
                                               bbox[1] - 2,
                                               '%s' % i,
                                               bbox=dict(facecolor='blue'),
                                               fontsize=9,
                                               color='white')

                            ax = plt.subplot(2, 1, 2)
                            plt.title("Cross-modal attention lang2vis")

                            tokenized_question = tokenizer.tokenize(
                                sent[datapoint])
                            tokenized_question = [
                                "<CLS>"
                            ] + tokenized_question + ["<SEP>"]

                            transposed_attention_map = lang2vis_attention_probs[
                                :len(tokenized_question), :num_regions]
                            im = plt.imshow(transposed_attention_map,
                                            vmin=0,
                                            vmax=1)

                            for i in range(len(tokenized_question)):
                                for j in range(num_regions):
                                    att_value = round(
                                        transposed_attention_map[i, j], 1)
                                    text = ax.text(
                                        j,
                                        i,
                                        att_value,
                                        ha="center",
                                        va="center",
                                        color="w" if att_value <= 0.5 else "b",
                                        fontsize=6)

                            ax.set_xticks(np.arange(num_regions))
                            ax.set_xticklabels(list(range(num_regions)))

                            ax.set_yticks(np.arange(len(tokenized_question)))
                            ax.set_yticklabels(tokenized_question)

                            plt.tight_layout()
                            # plt.gca().set_axis_off()
                            plt.savefig(
                                "/mnt/8tera/claudio.greco/guesswhat_lxmert/guesswhat/visualization_vqa/lang2vis_question_{}_layer_{}_head_{}.png"
                                .format(ques_id[datapoint].item(), layer,
                                        head),
                                bbox_inches='tight',
                                pad_inches=0.5)

                            plt.close()

                            ## vis2lang

                            plt.clf()

                            plt.subplot(2, 3, 1)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 0-7)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            plt.subplot(2, 3, 2)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 8-15)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            plt.subplot(2, 3, 3)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 16-35)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            img_info = loader.dataset.imgid2img[
                                question_id2img_id[ques_id[datapoint].item()]]
                            img_h, img_w = img_info['img_h'], img_info['img_w']
                            unnormalized_boxes = boxes[datapoint].clone()
                            unnormalized_boxes[:, (0, 2)] *= img_w
                            unnormalized_boxes[:, (1, 3)] *= img_h

                            for i, bbox in enumerate(unnormalized_boxes):
                                if i < 8:
                                    plt.subplot(2, 3, 1)
                                elif i < 16:
                                    plt.subplot(2, 3, 2)
                                else:
                                    plt.subplot(2, 3, 3)

                                bbox = [
                                    bbox[0].item(), bbox[1].item(),
                                    bbox[2].item(), bbox[3].item()
                                ]

                                if bbox[0] == 0:
                                    bbox[0] = 2
                                if bbox[1] == 0:
                                    bbox[1] = 2

                                plt.gca().add_patch(
                                    plt.Rectangle((bbox[0], bbox[1]),
                                                  bbox[2] - bbox[0] - 4,
                                                  bbox[3] - bbox[1] - 4,
                                                  fill=False,
                                                  edgecolor='red',
                                                  linewidth=1))

                                plt.gca().text(bbox[0],
                                               bbox[1] - 2,
                                               '%s' % i,
                                               bbox=dict(facecolor='blue'),
                                               fontsize=9,
                                               color='white')

                            ax = plt.subplot(2, 1, 2)
                            plt.title("Cross-modal attention vis2lang")

                            tokenized_question = tokenizer.tokenize(
                                sent[datapoint])
                            tokenized_question = [
                                "<CLS>"
                            ] + tokenized_question + ["<SEP>"]

                            transposed_attention_map = vis2lang_attention_probs.transpose()[
                                :len(tokenized_question), :num_regions]
                            im = plt.imshow(transposed_attention_map,
                                            vmin=0,
                                            vmax=1)

                            for i in range(len(tokenized_question)):
                                for j in range(num_regions):
                                    att_value = round(
                                        transposed_attention_map[i, j], 1)
                                    text = ax.text(
                                        j,
                                        i,
                                        att_value,
                                        ha="center",
                                        va="center",
                                        color="w" if att_value <= 0.5 else "b",
                                        fontsize=6)

                            ax.set_xticks(np.arange(num_regions))
                            ax.set_xticklabels(list(range(num_regions)))

                            ax.set_yticks(np.arange(len(tokenized_question)))
                            ax.set_yticklabels(tokenized_question)

                            plt.tight_layout()
                            # plt.gca().set_axis_off()
                            plt.savefig(
                                "/mnt/8tera/claudio.greco/guesswhat_lxmert/guesswhat/visualization_vqa/vis2lang_question_{}_layer_{}_head_{}.png"
                                .format(ques_id[datapoint].item(), layer,
                                        head),
                                bbox_inches='tight',
                                pad_inches=0.5)

                            plt.close()


                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
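
The per-cell annotation pattern in the visualization above generalizes to any small attention map. A self-contained sketch with a random (tokens x regions) matrix standing in for the real attention weights:

# Annotated attention heatmap, mirroring the plotting loop above;
# the attention matrix here is random, purely for illustration.
import numpy as np
import matplotlib.pyplot as plt

att = np.random.rand(5, 8)  # (tokens x regions)
fig, ax = plt.subplots()
ax.imshow(att, vmin=0, vmax=1)
for i in range(att.shape[0]):
    for j in range(att.shape[1]):
        v = round(float(att[i, j]), 1)
        ax.text(j, i, v, ha="center", va="center",
                color="w" if v <= 0.5 else "b", fontsize=6)
ax.set_xticks(np.arange(att.shape[1]))
ax.set_yticks(np.arange(att.shape[0]))
plt.tight_layout()
plt.savefig("attention_demo.png", bbox_inches="tight")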
Example #13
class VQA:
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            valid_bsize = args.get("valid_batch_size", 16)
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=valid_bsize,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.get("load_lxmert_pretrain", None) is not None:
            load_lxmert_from_pretrain_noqa(args.load_lxmert_pretrain,
                                           self.model)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()
            self.model.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        best_valid = 0.
        train_results = []
        report_every = args.get("report_every", 100)
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, batch in iter_wrapper(enumerate(loader)):
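                # Here the loader yields a list of per-example tuples (a
                # pass-through collate_fn, presumably), so zip(*batch)
                # transposes it into one tuple per field.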
                ques_id, feats, boxes, sent, tags, target = zip(*batch)
                self.model.train()
                self.optim.zero_grad()

                target = torch.stack(target).cuda()
                logit = self.model(feats, boxes, sent, tags)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()
                train_results.append(
                    pd.Series({"loss": loss.detach().mean().item()}))

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

                if i % report_every == 0 and i > 0:
                    print("Epoch: {}, Iter: {}/{}".format(
                        epoch, i, len(loader)))
                    print("    {}\n~~~~~~~~~~~~~~~~~~\n".format(
                        pd.DataFrame(train_results[-report_every:]).mean()))

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid and not args.get(
                        "special_test", False):
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
            if epoch >= 5:
                self.save("Epoch{}".format(epoch))
            print(log_str, end='')
            print(args.output)

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, batch in enumerate(tqdm(loader)):
            ques_id, feats, boxes, sent, tags = list(zip(*batch))[:5]  # avoid seeing ground truth
            with torch.no_grad():
                #target = torch.stack(target).cuda()
                logit = self.model(feats, boxes, sent, tags)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
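
# A minimal standalone sketch of the soft-target VQA objective shared by the
# trainers in this file; shapes and values below are illustrative only.
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()           # mean reduction over batch * num_answers
demo_logit = torch.randn(2, 3129)      # [batch, num_answers]
demo_target = torch.zeros(2, 3129)     # VQA soft answer scores in [0, 1]
demo_target[0, 5] = 0.9
demo_loss = bce(demo_logit, demo_target) * demo_logit.size(1)  # per-question sum
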
class VQA:
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=1024,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers,
                              finetune_strategy=args.finetune_strategy)

        # Policy network for SpotTune-style adaptive fine-tuning
        if args.finetune_strategy in PolicyStrategies:
            self.policy_model = PolicyLXRT(
                PolicyStrategies[args.finetune_strategy])

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.finetune_strategy in PolicyStrategies:
            self.policy_model = self.policy_model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()
            self.policy_model.policy_lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Optimizer for policy net
        if args.finetune_strategy in PolicyStrategies:
            self.policy_optim = args.policy_optimizer(
                self.policy_model.parameters(), args.policy_lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple, visualizer=None):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        wandb.watch(self.model, log='all')
        if args.finetune_strategy in PolicyStrategies:
            wandb.watch(self.policy_model, log='all')

        best_valid = 0.

        for epoch in range(args.epochs):
            # for policy vec plotting
            if args.finetune_strategy in PolicyStrategies:
                policy_save = torch.zeros(
                    PolicyStrategies[args.finetune_strategy] // 2).cpu()
                policy_max = 0
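                # policy_save accumulates, per fine-tunable block, how often
                # the sampled policy picked the fine-tuned path; dividing by
                # policy_max (the number of samples) gives the per-block rate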

            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent,
                    target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                if args.finetune_strategy in PolicyStrategies:
                    self.policy_model.train()
                    self.policy_optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda(
                )

                if args.finetune_strategy in PolicyStrategies:
                    # calculate the policy vector here
                    policy_vec = self.policy_model(feats, boxes, sent)
                    policy_action = gumbel_softmax(
                        policy_vec.view(policy_vec.size(0), -1, 2))
                    policy = policy_action[:, :, 1]
                    policy_save = policy_save + policy.clone().detach().cpu(
                    ).sum(0)
                    policy_max += policy.size(0)
                    logit = self.model(feats, boxes, sent, policy)
                else:
                    logit = self.model(feats, boxes, sent)

                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()
                if args.finetune_strategy in PolicyStrategies:
                    self.policy_optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            # plot policy usage (policy_save only exists for policy strategies)
            if visualizer is not None and args.finetune_strategy in PolicyStrategies:
                print(f'Creating training visualizations for epoch {epoch}')
                visualizer.plot(policy_save,
                                policy_max,
                                epoch=epoch,
                                mode='train')

            train_acc = evaluator.evaluate(quesid2ans) * 100.
            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, train_acc)

            wandb.log({'Training Accuracy': train_acc})

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple,
                                            epoch=epoch,
                                            visualizer=visualizer)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

                wandb.log({'Validation Accuracy': valid_score * 100.})

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self,
                eval_tuple: DataTuple,
                dump=None,
                epoch=0,
                visualizer=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        if args.finetune_strategy in PolicyStrategies:
            self.policy_model.eval()
            policy_save = torch.zeros(
                PolicyStrategies[args.finetune_strategy] // 2)
            policy_max = 0

        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                if args.finetune_strategy in PolicyStrategies:
                    # calculate the policy vector here
                    policy_vec = self.policy_model(feats, boxes, sent)
                    policy_action = gumbel_softmax(
                        policy_vec.view(policy_vec.size(0), -1, 2))
                    policy = policy_action[:, :, 1]
                    policy_save = policy_save + policy.clone().detach().cpu(
                    ).sum(0)
                    policy_max += policy.size(0)
                    logit = self.model(feats, boxes, sent, policy)
                else:
                    logit = self.model(feats, boxes, sent)

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

        if visualizer is not None and args.finetune_strategy in PolicyStrategies:
            print(f'Creating validation visualization for epoch {epoch}...')
            visualizer.plot(policy_save, policy_max, epoch=epoch, mode='val')

        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self,
                 eval_tuple: DataTuple,
                 dump=None,
                 epoch=0,
                 visualizer=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple,
                                  dump,
                                  epoch=epoch,
                                  visualizer=visualizer)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
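
# A minimal sketch of the straight-through policy sampling used above, assuming
# the repo's gumbel_softmax behaves like torch.nn.functional.gumbel_softmax
# with hard=True (illustrative shapes only):
import torch
import torch.nn.functional as F

demo_policy_vec = torch.randn(4, 24)                   # 12 blocks x 2 logits each
demo_action = F.gumbel_softmax(demo_policy_vec.view(4, -1, 2),
                               tau=1.0, hard=True, dim=-1)
demo_policy = demo_action[:, :, 1]                     # 1 = fine-tune, 0 = frozen
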
class NLVR2:
    def __init__(self):
        self.train_tuple = get_tuple(args.train,
                                     bs=args.batch_size,
                                     shuffle=True,
                                     drop_last=True)
        if args.valid != "":
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(args.valid,
                                         bs=valid_bsize,
                                         shuffle=False,
                                         drop_last=False)
        else:
            self.valid_tuple = None

        self.model = NLVR2Model()

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)

        # GPU options
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()
        self.model = self.model.cuda()

        # Losses and optimizer
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent,
                    label) in iter_wrapper(enumerate(loader)):
                self.model.train()

                self.optim.zero_grad()
                feats, boxes, label = feats.cuda(), boxes.cuda(), label.cuda()
                logit = self.model(feats, boxes, sent)

                loss = self.mce_loss(logit, label)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, predict = logit.max(1)
                for qid, l in zip(ques_id, predict.cpu().numpy()):
                    quesid2ans[qid] = l

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # avoid handling target
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, predict = logit.max(1)
                for qid, l in zip(ques_id, predict.cpu().numpy()):
                    quesid2ans[qid] = l
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        dset, loader, evaluator = eval_tuple
        quesid2ans = self.predict(eval_tuple, dump)
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
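
# Sketch of the NLVR2 classification loss above: a two-way cross-entropy where
# ignore_index=-1 makes unlabeled examples contribute nothing (values are
# illustrative):
import torch
import torch.nn as nn

demo_mce = nn.CrossEntropyLoss(ignore_index=-1)
demo_logit = torch.randn(4, 2)
demo_label = torch.tensor([0, 1, -1, 1])   # the -1 entry is skipped
demo_loss = demo_mce(demo_logit, demo_label)
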
class VCSD:
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True,
                                          resize_img=args.resize_img)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=1024,
                                              shuffle=False,
                                              drop_last=False,
                                              resize_img=args.resize_img)
        else:
            self.valid_tuple = None

        # Model
        self.model = VCSDModel()

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        # if args.load_lxmert_qa is not None:
        #     load_lxmert_qa(args.load_lxmert_qa, self.model,
        #                    label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            datumid2pred = {}
            for i, (datum_id, raw_image_id, image_id, utterance, response, img,
                    bboxes, target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                img, target = img.cuda(), target.cuda()
                bboxes = bboxes.cuda()

                logit = self.model(utterance, response, img, bboxes)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for did, l in zip(datum_id, label.cpu().numpy()):
                    datumid2pred[did.item()] = l
            tp, tn, fp, fn = evaluator.evaluate(datumid2pred)
            accu = (tp + tn) / (tp + fp + fn + tn)
            log_str = "\nEpoch %d: Train accuracy %0.2f\n" % (epoch,
                                                              accu * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_accu, valid_prec, valid_rec, valid_f1 = self.evaluate(
                    eval_tuple)
                if valid_accu > best_valid:
                    best_valid = valid_accu
                    self.save("BEST")

                log_str += "Epoch %d: Valid accuracy %0.2f precision %0.2f recall %0.2f F1 %0.2f\n" % \
                           (epoch, valid_accu * 100., valid_prec, valid_rec, valid_f1) + \
                           "Epoch %d: Best accuracy %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        datumid2pred = {}
        for i, datum_tuple in enumerate(loader):
            datum_id, raw_image_id, image_id, utterance, response, img, bboxes = \
                datum_tuple[:7]  # avoid seeing ground truth
            with torch.no_grad():
                img = img.cuda()
                bboxes = bboxes.cuda()
                logit = self.model(utterance, response, img, bboxes)
                score, label = logit.max(1)
                for did, l in zip(datum_id, label.cpu().numpy()):
                    datumid2pred[did.item()] = l
        if dump is not None:
            evaluator.dump_output(datumid2pred, dump)
        return datumid2pred

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple; returns accuracy, precision,
        recall and F1 computed from the confusion counts."""
        datumid2preds = self.predict(eval_tuple, dump)
        tp, tn, fp, fn = eval_tuple.evaluator.evaluate(datumid2preds)

        accu = (tp + tn) / (tp + fp + fn + tn)
        precision = (tp / (tp + fp)) if (tp + fp) > 0 else 0
        recall = (tp / (tp + fn)) if (tp + fn) > 0 else 0
        fmeasure = 0
        if precision + recall > 0:
            fmeasure = 2 * precision * recall / (precision + recall)
        return accu, precision, recall, fmeasure

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        datumid2preds = {}
        for i, (datum_id, raw_image_id, image_id, utterance, response, img,
                bboxes, target) in enumerate(loader):
            _, label = target.max(1)
            for did, l in zip(datum_id, label.cpu().numpy()):
                datumid2preds[did.item()] = l  # key by datum id, not the loop index
        return evaluator.evaluate(datumid2preds)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
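
# Quick sanity check of VCSD.evaluate's confusion-matrix arithmetic, using
# hypothetical counts:
demo_tp, demo_tn, demo_fp, demo_fn = 40, 45, 10, 5
demo_precision = demo_tp / (demo_tp + demo_fp)     # 0.80
demo_recall = demo_tp / (demo_tp + demo_fn)        # ~0.889
demo_f1 = 2 * demo_precision * demo_recall / (demo_precision + demo_recall)  # ~0.842
demo_accu = (demo_tp + demo_tn) / (demo_tp + demo_tn + demo_fp + demo_fn)    # 0.85
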
class Rank:
    def __init__(self):
        if args.train_json != "-1":
            self.train_tuple = get_tuple(args.train_json,
                                         bs=args.batch_size,
                                         shuffle=True,
                                         drop_last=True)
        else:
            self.train_tuple = None

        if args.valid_json != "-1":
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(args.valid_json,
                                         bs=valid_bsize,
                                         shuffle=False,
                                         drop_last=False)
        else:
            self.valid_tuple = None

        self.model = RankModel(model_type=args.model_type)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Losses and optimizer, only if training
        if args.train_json != "-1":
            self.rank_loss = nn.BCEWithLogitsLoss()
            if 'bert' in args.optim:
                batch_per_epoch = len(self.train_tuple.loader)
                t_total = int(batch_per_epoch * args.epochs)
                print("Total Iters: %d" % t_total)
                from lxrt.optimization import BertAdam
                self.optim = BertAdam(list(self.model.parameters()),
                                      lr=args.lr,
                                      warmup=0.1,
                                      t_total=t_total)
            else:
                self.optim = args.optimizer(list(self.model.parameters()),
                                            args.lr)

        self.output = args.output_dir
        self.best_name = None

        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            instance_id2pred = {}
            for i, (instance_ids, f0, b0, f1, b1, sent0, sent1, logit_in,
                    label) in iter_wrapper(enumerate(loader)):
                self.model.train()
                self.optim.zero_grad()
                f0, b0 = f0.cuda(), b0.cuda()
                f1, b1 = f1.cuda(), b1.cuda()

                label = label.cuda()
                logit_in = logit_in.cuda()

                score0 = self.model(f0, b0, sent0)
                score1 = self.model(f1, b1, sent1)
                logit = score0 - score1 + logit_in

                if label.dim() == 1:  #expand targets in binary mode
                    assert logit.size(1) == 1
                    label = label.unsqueeze(1)

                loss = self.rank_loss(logit, label)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                predict = (logit > 0).float()

                for instance_id, l, score, c_score_0, c_score_1 in zip(
                        instance_ids,
                        predict.cpu().numpy(),
                        logit.detach().cpu().numpy(),
                        score0.detach().cpu().numpy(),
                        score1.detach().cpu().numpy()):
                    instance_id2pred[instance_id] = {
                        'label': float(l),
                        'scores': score,
                        'score0': float(c_score_0[0]),
                        'score1': float(c_score_1[0])
                    }

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(instance_id2pred) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.best_name = 'epoch_{}_valscore_{:.5f}_argshash_{}'.format(
                        epoch, valid_score * 100., args.args_hash)
                    self.save(self.best_name)

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        instance_id2pred = {}
        for i, datum_tuple in tqdm(enumerate(loader), total=len(loader)):
            instance_ids, f0, b0, f1, b1, sent0, sent1, logit_in = datum_tuple[:-1]
            with torch.no_grad():
                f0, b0 = f0.cuda(), b0.cuda()
                f1, b1 = f1.cuda(), b1.cuda()

                score0 = self.model(f0, b0, sent0)
                score1 = self.model(f1, b1, sent1)

                logit_in = logit_in.cuda()

                logit = score0 - score1 + logit_in
                predict = (logit > 0).float()

                for instance_id, l, score, c_score_0, c_score_1 in zip(
                        instance_ids,
                        predict.cpu().numpy(),
                        logit.detach().cpu().numpy(),
                        score0.cpu().numpy(),
                        score1.cpu().numpy()):
                    instance_id2pred[instance_id] = {
                        'label': float(l),
                        'scores': score,
                        'score0': float(c_score_0[0]),
                        'score1': float(c_score_1[0])
                    }

        if dump is not None:
            evaluator.dump_result(instance_id2pred, dump)

        return instance_id2pred

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        dset, loader, evaluator = eval_tuple
        instance_id2pred = self.predict(eval_tuple, dump)
        return evaluator.evaluate(instance_id2pred)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        instance_id2pred = {}
        for i, (instance_ids, f0, b0, f1, b1, sent0, sent1, logit_in,
                label) in enumerate(loader):
            for instance_id, l in zip(instance_ids, label.cpu().numpy()):
                instance_id2pred[instance_id] = {'label': float(l)}
        return evaluator.evaluate(instance_id2pred)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
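
# A minimal sketch of the pairwise ranking objective used by Rank above:
# BCEWithLogitsLoss applied to the score margin (score0 - score1), i.e.
# logistic pairwise ranking; logit_in appears to shift that margin by a prior
# score (values below are illustrative):
import torch
import torch.nn as nn

demo_rank_loss = nn.BCEWithLogitsLoss()
demo_score0 = torch.randn(8, 1)                   # hypothetical candidate-0 scores
demo_score1 = torch.randn(8, 1)                   # hypothetical candidate-1 scores
demo_label = torch.randint(0, 2, (8, 1)).float()  # 1 if candidate 0 should rank higher
demo_loss = demo_rank_loss(demo_score0 - demo_score1, demo_label)
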
class GQA:
    def __init__(self):
        self.train_tuple = get_tuple(args.train,
                                     bs=args.batch_size,
                                     shuffle=True,
                                     drop_last=True)
        if args.valid != "":
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(args.valid,
                                         bs=valid_bsize,
                                         shuffle=False,
                                         drop_last=False)
        else:
            self.valid_tuple = None

        self.model = GQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            self.new_ans_label = load_lxmert_qa(
                args.load_lxmert_qa,
                self.model,
                label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Losses and optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        # self.KL_loss = nn.KLDivLoss(reduction='none')
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

        # Tensorboard
        self.boards_dir = os.path.join('boards', self.output)
        if not os.path.exists(self.boards_dir):
            os.makedirs(self.boards_dir)
        self.writerTbrd = SummaryWriter(self.boards_dir)

        # get Glove projection for all answers
        if args.answer_loss == 'glove':
            path_glove = './data/GloVe/GloVeDict.pkl'
            with open(path_glove, 'rb') as f:
                glove_dic = pickle.load(f)
            glove_dim = glove_dic['the'].shape[-1]
            print("Loading Glove%d answer's vector" % glove_dim)
            self.labelans2glove = []
            self.valid_ans_embed = [1] * len(
                self.train_tuple.dataset.label2ans)
            for label, ans in enumerate(self.train_tuple.dataset.label2ans):
                ans = ans.split(' ')
                glove_ans = []
                for w in ans:
                    #print(w)
                    try:
                        glove_ans.append(glove_dic[w])
                    except KeyError:
                        #print('Full ans: %s' % ans)
                        #input(' ')
                        self.valid_ans_embed[label] = 0
                        glove_ans.append(np.zeros(glove_dim))
                #print(glove_ans)
                glove_ans = torch.tensor(glove_ans).mean(-2)
                self.labelans2glove.append(torch.tensor(glove_ans))
            #print(self.labelans2glove)
            print(
                'Ratio of valid ans embedding: %f' %
                (float(sum(self.valid_ans_embed)) / len(self.valid_ans_embed)))
            self.labelans2glove = torch.stack(
                self.labelans2glove).float().cuda()
            self.cosineSim = torch.nn.CosineSimilarity(dim=1, eps=1e-08)
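            # labelans2glove holds one mean-pooled GloVe vector per answer; it
            # is later compared, via cosine similarity, against the expectation
            # of these vectors under the model's predicted answer distribution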

# ? DEBUG CORENTIN ****************************************************************

    def check_pointer_manually(self, train_tuple):
        IMAGE_PATH = 'data/gqa/images'
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)
        for i, (ques_id, feats, boxes, sent, target, iou_question, iou_answer)\
                 in iter_wrapper(enumerate(loader)):

            for batch_index in range(len(ques_id)):
                datum = dset.id2datum[ques_id[batch_index]]
                # Load image
                im_id = datum['image_id']
                im_path = os.path.join(IMAGE_PATH, '%s.jpg' % im_id)
                image_pil = Image.open(im_path)
                im = np.array(image_pil, dtype=np.uint8)
                height = image_pil.height
                width = image_pil.width
                # Load annotations
                question = datum['sent']
                question_pointer = datum['pointer']['question']
                answer = list(datum['label'])[0]
                answer_pointer = datum['pointer']['answer']
                # * * Display pointer and bboxes
                # Create plot
                fig = plt.figure()
                plt.suptitle(question)
                plt.title(answer)
                ax = fig.add_subplot(1, 1, 1)
                # draw image
                ax.imshow(im)

                # draw detected boxes
                def iou_preprocess(iou):
                    THRESHOLD = 0.1
                    TOPK = 5
                    # keep only the TOPK highest IoUs, softmax over them, and
                    # zero the distribution when the kept mass is negligible
                    sorted_idx = np.argsort(iou)[::-1]
                    iou_topk = iou.copy()  # copy: don't clobber the caller's array
                    iou_topk[sorted_idx[TOPK:]] = -1e9
                    f_iou = np.exp(iou_topk) / np.sum(np.exp(iou_topk), axis=0)
                    f_iou = f_iou * (iou_topk.clip(min=0).sum() >= THRESHOLD)
                    return f_iou

                detected_bboxe = boxes[batch_index] * torch.tensor(
                    [width, height, width, height]).float()
                total_iou_per_object = np.zeros((boxes.size(1)))
                for _, pointer in question_pointer.items():
                    iou = np.array(pointer['iou'])
                    f_iou = iou_preprocess(iou)
                    total_iou_per_object += f_iou
                for _, pointer in answer_pointer.items():
                    iou = np.array(pointer['iou'])
                    f_iou = iou_preprocess(iou)
                    total_iou_per_object += f_iou
                intensity = total_iou_per_object.clip(min=0, max=1)
                c = [np.array([0, 0, 1]) for j in range(boxes.size(1))]
                draw_bboxes(detected_bboxe.numpy(),
                            ax,
                            color=c,
                            alpha=intensity)
                # draw pointer boxes
                for word_id, pointer in question_pointer.items():
                    bboxe = pointer['boxe']
                    bboxe = [
                        bboxe[0] * width, bboxe[1] * height, bboxe[2] * width,
                        bboxe[3] * height
                    ]
                    c = [np.array([1, 0, 0])]
                    draw_bboxes([bboxe], ax, color=c, label=['q_%s' % word_id])
                for word_id, pointer in answer_pointer.items():
                    bboxe = pointer['boxe']
                    bboxe = [
                        bboxe[0] * width, bboxe[1] * height, bboxe[2] * width,
                        bboxe[3] * height
                    ]
                    c = [np.array([0, 1, 0])]
                    draw_bboxes([bboxe], ax, color=c, label=['a_%s' % word_id])
                plt.savefig('check_pointer_%s_sftmx.jpg' % im_id)
                plt.close()
                # * * Retrieve statistics

                # input('Press ENTER for next image')

    def pointer_stats(self, train_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        # Stat watchers
        pointer_per_question = []
        nb_words_per_question = []
        pointed_words = {}
        max_iou = []
        top5_cumiou = []
        top10_cumiou = []
        total_cumiou = []

        start = time.time()
        for i, (ques_id, feats, boxes, sent, target, iou_question, iou_answer)\
                 in iter_wrapper(enumerate(loader)):

            for batch_index in range(len(ques_id)):
                datum = dset.id2datum[ques_id[batch_index]]

                # Load annotations
                question = datum['sent']
                question_pointer = datum['pointer']['question']
                answer = list(datum['label'])[0]
                answer_pointer = datum['pointer']['answer']
                # parse question and answer
                parsed_question = question.translate(
                    str.maketrans('', '', string.punctuation)).split(' ')
                parsed_answer = answer.translate(
                    str.maketrans('', '', string.punctuation)).split(' ')
                # Stats words
                pointer_per_question.append(
                    len(question_pointer) + len(answer_pointer))
                nb_words_per_question.append(
                    len(parsed_question) + len(parsed_answer))

                def add2dic(pointer, parsed_sent):
                    for w_idx in pointer:
                        if ':' in w_idx:
                            indexes = w_idx.split(':')
                            for j in range(int(indexes[0]), int(indexes[1])):
                                word = parsed_sent[j]
                                if word in pointed_words:
                                    pointed_words[word] += 1
                                else:
                                    pointed_words[word] = 1
                        else:
                            word = parsed_sent[int(w_idx)]
                            if word in pointed_words:
                                pointed_words[word] += 1
                            else:
                                pointed_words[word] = 1

                add2dic(question_pointer, parsed_question)
                add2dic(answer_pointer, parsed_answer)
                # Stats IoU
                # max
                max_iou_question = iou_question.max(-1)[0]
                max_iou += max_iou_question[
                    max_iou_question > 0].flatten().tolist()
                max_iou_answer = iou_answer.max(-1)[0]
                max_iou += max_iou_answer[
                    max_iou_answer > 0].flatten().tolist()
                # top5
                top5_cumiou_question = iou_question.topk(5)[0].sum(-1)
                top5_cumiou_answer = iou_answer.topk(5)[0].sum(-1)
                top5_cumiou += top5_cumiou_question[
                    top5_cumiou_question > 0].flatten().tolist()
                top5_cumiou += top5_cumiou_answer[
                    top5_cumiou_answer > 0].flatten().tolist()
                # top10
                top10_cumiou_question = iou_question.topk(10)[0].sum(-1)
                top10_cumiou_answer = iou_answer.topk(10)[0].sum(-1)
                top10_cumiou += top10_cumiou_question[
                    top10_cumiou_question > 0].flatten().tolist()
                top10_cumiou += top10_cumiou_answer[
                    top10_cumiou_answer > 0].flatten().tolist()
                # total
                total_cumiou_question = iou_question.sum(-1)
                total_cumiou_answer = iou_answer.sum(-1)
                total_cumiou += total_cumiou_question[
                    total_cumiou_question > 0].flatten().tolist()
                total_cumiou += total_cumiou_answer[
                    total_cumiou_answer > 0].flatten().tolist()

        elapsed_time = time.time() - start
        print("Time: %.1fmin" % (elapsed_time / 60))
        # Save into pickle dic
        dic = {
            'pointer_per_question': pointer_per_question,
            'nb_words_per_question': nb_words_per_question,
            'pointed_words': pointed_words,
            'max_iou': max_iou,
            'top5_cumiou': top5_cumiou,
            'top10_cumiou': top10_cumiou,
            'total_cumiou': total_cumiou
        }
        with open('stats_pointer.pickle', 'wb') as handle:
            pickle.dump(dic, handle, protocol=pickle.HIGHEST_PROTOCOL)

# ? DEBUG CORENTIN ****************************************************************

# ? Manual evaluation (for matching) **********************************************

    def eval_matching_manual(self, eval_tuple):
        IMAGE_PATH = 'data/gqa/images'
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent, _, iou_question, iou_answer = datum_tuple
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                iou_question, iou_answer = iou_question.cuda(
                ), iou_answer.cuda()
                logit, iou_target, iou_pred = self.model(
                    feats, boxes, sent, iou_question, iou_answer)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

            for batch_index in range(len(ques_id)):
                # Retrieve info + prediction
                qid = ques_id[batch_index]
                datum = dset.id2datum[qid]
                question = datum['sent']
                answer_pred = quesid2ans[qid]
                # Load image
                im_id = datum['image_id']
                im_path = os.path.join(IMAGE_PATH, '%s.jpg' % im_id)
                image_pil = Image.open(im_path)
                im = np.array(image_pil, dtype=np.uint8)
                height = image_pil.height
                width = image_pil.width
                detected_bboxe = boxes[batch_index].cpu() * torch.tensor(
                    [width, height, width, height]).float()
                # Display iou prediction
                for w in range(iou_pred[batch_index].size(0)):
                    fig = plt.figure()
                    plt.suptitle('Q:%s __ Idx:%d' % (question, w))
                    plt.title(answer_pred)
                    # all predicted bboxes
                    ax = fig.add_subplot(2, 1, 1)
                    ax.imshow(im)
                    c = [np.array([0, 0, 1]) for j in range(boxes.size(1))]
                    draw_bboxes(detected_bboxe.numpy(), ax, color=c)
                    # matched bboxes
                    iou = iou_pred[batch_index, w]
                    iou_norm = iou / (iou.sum() + 1e-9)
                    ax = fig.add_subplot(2, 1, 2)
                    ax.imshow(im)
                    c = [np.array([0, 0, 1]) for j in range(boxes.size(1))]
                    draw_bboxes(detected_bboxe.numpy(),
                                ax,
                                color=c,
                                alpha=iou_norm)
                    plt.savefig('t0.4_pred_pointer_%s_%d.jpg' % (im_id, w))
                    plt.close()
                input('Press ENTER for next image')


# ? Manual evaluation (for matching) **********************************************

    def gqa_analysis(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)
        for i, (ques_id, feats, boxes, sent, target, iou_question, iou_answer, sem_question_words, sem_answer_words, bboxes_words,)\
             in iter_wrapper(enumerate(loader)):

            with torch.no_grad():
                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda(
                )
                iou_question, iou_answer = iou_question.cuda(
                ), iou_answer.cuda()
                sem_question_words, sem_answer_words, bboxes_words = sem_question_words.cuda(
                ), sem_answer_words.cuda(), bboxes_words.cuda()
                logit, iou_target, iou_score, lang_feat, vis_feat, tkn_sent = self.model(
                    feats,
                    boxes,
                    sent,
                    iou_question,
                    iou_answer,
                    sem_question_words,
                    sem_answer_words,
                    bboxes_words,
                    verbose=True)
                for b in range(lang_feat.size(0)):
                    len_sent = len(tkn_sent[b])
                    self.writerTbrd.add_embedding(lang_feat[b, :len_sent],
                                                  metadata=tkn_sent[b])

    def extract_maps(self, eval_tuple):
        self.model.eval()
        #self.model.cpu()
        dset, loader, evaluator = eval_tuple
        timer = time.time()
        att_maps = None
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent, target, iou_question, iou_answer, sem_question_words, sem_answer_words, bboxes_words = datum_tuple
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                iou_question, iou_answer = iou_question.cuda(
                ), iou_answer.cuda()
                sem_question_words, sem_answer_words, bboxes_words = sem_question_words.cuda(
                ), sem_answer_words.cuda(), bboxes_words.cuda()
                logit, iou_target, iou_score, activations = self.model(
                    feats, boxes, sent, iou_question, iou_answer,
                    sem_question_words, sem_answer_words, bboxes_words)
                score, label = logit.max(1)

            # init according to model's architecture
            activations['cross'] = [
                item for sublist in activations['cross'] for item in sublist
            ]  # flatten
            if att_maps is None:
                print('map shape', activations['lang'][0].shape)
                n_head = activations['lang'][0].shape[1]
                att_maps = {
                    'lang': [
                        torch.zeros((n_head))
                        for t in range(len(activations['lang']))
                    ],
                    'vis': [
                        torch.zeros((n_head))
                        for t in range(len(activations['vis']))
                    ],
                    'cross': [
                        torch.zeros((n_head))
                        for t in range(len(activations['cross']))
                    ]
                }
                map_names = []
            for maptype in ['lang', 'vis', 'cross']:
                for idx, maps in enumerate(activations[maptype]):
                    if maptype == 'cross':
                        sub_id = idx % 4
                        sub_maptype = ['xvl', 'xlv', 'xl', 'xv']
                        name = '%s%d' % (sub_maptype[sub_id], idx)
                    else:
                        name = '%s%d' % (maptype[0], idx)
                    if i == 0:  # record layer names once, not once per batch
                        map_names.append(name)
                    d_b, d_h, d_1, d_2 = maps.shape  # [batch, head, d1, d2]
                    head_max = torch.max(maps.view(d_b, d_h, d_1 * d_2),
                                         dim=-1).values  # [batch, heads]
                    head_max = head_max.sum(0)  # sum over batch: [heads]
                    att_maps[maptype][idx] += head_max.cpu().data

        all_max_maps = torch.cat([
            torch.stack(att_maps['lang']),
            torch.stack(att_maps['vis']),
            torch.stack(att_maps['cross'])
        ]).numpy()  # [layer, heads]
        all_max_maps = all_max_maps / len(dset)
        print("MAX ATT divided", all_max_maps)
        # Draw histogram
        FIG_PATH = self.output
        draw_histogram(all_max_maps, path=FIG_PATH, labels=map_names)
        input('_')

        print('Attention maps extracted in %ds' % (time.time() - timer))

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        best_valid = 0.
        optim_steps = 0
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target, iou_question, iou_answer, sem_question_words, sem_answer_words, bboxes_words,)\
                 in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                # DEBUG: print pointer (set batch size to 1)
                # print(dset.id2datum[ques_id[0]]['sent'])
                # print(dset.id2datum[ques_id[0]]['label'])
                # q_pointer = dset.id2datum[ques_id[0]]['pointer']['question']
                # for w_index in q_pointer:
                #     print(w_index)

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda(
                )
                iou_question, iou_answer = iou_question.cuda(
                ), iou_answer.cuda()
                sem_question_words, sem_answer_words, bboxes_words = sem_question_words.cuda(
                ), sem_answer_words.cuda(), bboxes_words.cuda()
                logit, iou_target, iou_score = self.model(
                    feats, boxes, sent, iou_question, iou_answer,
                    sem_question_words, sem_answer_words, bboxes_words)
                assert logit.dim() == target.dim() == 2
                if args.mce_loss:
                    max_value, target = target.max(1)
                    loss = self.mce_loss(logit, target) * logit.size(1)
                else:
                    loss = self.bce_loss(logit, target)
                    loss = loss * logit.size(1)
                #print('CE', loss.item())

                if args.answer_loss == 'glove':
                    # answer embedding weighted by the gold soft scores
                    gold_glove = (self.labelans2glove.unsqueeze(0) *
                                  target.unsqueeze(-1)).sum(1)
                    # expected answer embedding under the predicted distribution
                    pred_glove = (
                        self.labelans2glove.unsqueeze(0) *
                        torch.softmax(logit, dim=1).unsqueeze(-1)).sum(1)
                    # reward semantic closeness between predicted and gold answers
                    sim_answer = self.cosineSim(gold_glove, pred_glove).mean()
                    loss += -10 * sim_answer

                if optim_steps % 1000 == 0:
                    self.writerTbrd.add_scalar('vqa_loss_train', loss.item(),
                                               optim_steps)

                ALPHA = args.alpha_pointer  # weight of the pointer matching loss

                def iou_preprocess(iou, obj_conf=None):
                    THRESHOLD = 0.1
                    TOPK = 3
                    # keep the TOPK highest IoUs per word, softmax over them,
                    # and mask out words whose kept IoU mass is negligible
                    sorted_values = torch.sort(iou, descending=True, dim=-1)[0]
                    t_top = sorted_values[:, :, TOPK - 1]
                    iou_topk = iou.masked_fill(iou < t_top.unsqueeze(-1), -1e9)
                    f_iou = torch.softmax(iou_topk, dim=-1)
                    threshold_mask = (iou_topk.clamp(min=.0).sum(-1) >=
                                      THRESHOLD).float()
                    if args.task_pointer == 'KLDiv':
                        return f_iou, threshold_mask
                    elif args.task_pointer == 'Triplet':
                        # Remove the top10 most similar objects
                        t_bot = sorted_values[:, :, 10]
                        iou_botk = (iou < t_bot.unsqueeze(-1)).float()
                        # Take the TOPK most confident objects as negatives
                        conf_top = torch.sort(obj_conf.unsqueeze(1) * iou_botk,
                                              descending=True,
                                              dim=-1)[0][:, :, TOPK - 1]
                        conf_mask = obj_conf.unsqueeze(1).expand(
                            -1, iou.size(1), -1) >= conf_top.unsqueeze(-1)
                        neg_score = iou_botk * conf_mask.float()
                        return f_iou, threshold_mask, neg_score

                if args.task_pointer == 'KLDiv':
                    iou_target_preprocess, threshold_mask = iou_preprocess(
                        iou_target)
                    loss_pointer_fct = KLDivLoss(reduction='none')
                    iou_pred = torch.log_softmax(iou_score, dim=-1)
                    matching_loss = loss_pointer_fct(
                        input=iou_pred, target=iou_target_preprocess)
                    # average the per-word KL over words that pass the IoU-mass
                    # threshold (unsupervised words contribute zero)
                    matching_loss = ALPHA * (matching_loss.sum(-1) *
                                             threshold_mask).sum() / (
                                                 threshold_mask.sum() + 1e-9)
                    if optim_steps % 1000 == 0:
                        self.writerTbrd.add_scalar('pointer_loss_train',
                                                   matching_loss.item(),
                                                   optim_steps)
                    loss += matching_loss

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()
                optim_steps += 1

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

                # if self.valid_tuple is not None and optim_steps % 1152 == 0:  # Do Validation
                #     valid_score = self.evaluate(eval_tuple)
                #     fastepoch = int(optim_steps / 1152)
                #     print("fastEpoch %d: Valid %0.2f\n" % (fastepoch, valid_score * 100.,))

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                self.writerTbrd.add_scalar('vqa_acc_valid', valid_score, epoch)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None, iou=False):
        self.model.eval()
        #self.model.cpu()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        quesid2iou = {}
        timer = time.time()
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent, target, iou_question, iou_answer, sem_question_words, sem_answer_words, bboxes_words = datum_tuple
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                iou_question, iou_answer = iou_question.cuda(), iou_answer.cuda()
                sem_question_words, sem_answer_words, bboxes_words = (
                    sem_question_words.cuda(), sem_answer_words.cuda(),
                    bboxes_words.cuda())
                logit, iou_target, iou_score = self.model(
                    feats, boxes, sent, iou_question, iou_answer,
                    sem_question_words, sem_answer_words, bboxes_words)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans
                    quesid2iou[qid] = None  # iou_pred placeholder, one entry per question
        print('Processed set in %ds' % (time.time() - timer))
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        if iou:
            return quesid2ans, quesid2iou
        else:
            return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        dset, loader, evaluator = eval_tuple
        quesid2ans = self.predict(eval_tuple, dump)
        return evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target, iou_question, iou_answer,
                sem_question_words, sem_answer_words, bboxes_words) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
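        # strip the '.module' infix that nn.DataParallel inserts into
        # parameter names when the encoder was wrapped via multi_gpu()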
        for key in list(state_dict.keys()):
            if '.module' in key:
                state_dict[key.replace('.module', '')] = state_dict.pop(key)
        self.model.load_state_dict(state_dict, strict=False)

    def finetune(self, train_tuple, eval_tuple):
        # log
        output_1 = os.path.join(self.output, 'finetune_1')
        os.makedirs(output_1, exist_ok=True)
        output_2 = os.path.join(self.output, 'finetune_2')
        os.makedirs(output_2, exist_ok=True)

        # Tensorboard
        boards_dir_1 = os.path.join(self.boards_dir, 'finetune_1')
        if not os.path.exists(boards_dir_1):
            os.makedirs(boards_dir_1)
        boards_dir_2 = os.path.join(self.boards_dir, 'finetune_2')
        if not os.path.exists(boards_dir_2):
            os.makedirs(boards_dir_2)

        # Params
        lr_1 = args.lr
        lr_2 = args.lr / 10
        epochs_1 = 4  #int(args.epochs / 3)
        epochs_2 = args.epochs - epochs_1

        # Step 0: evaluate pretraining
        if self.valid_tuple is not None:  # Do Validation
            valid_score = self.evaluate(eval_tuple)
            print("Before finetune: Valid %0.2f\n" % (valid_score * 100.))

        # Step 0.1: finetune new ans only
        # new_ans_params = []
        # for name, p in self.model.named_parameters():
        #     if "logit_fc.3" in name:
        #         for idx in range(p.size(0)):
        #             if idx in self.new_ans_label:
        #                 new_ans_params.append({'params': p[idx]})

        # args.epochs = epochs_0
        # from lxrt.optimization import BertAdam
        # self.optim = BertAdam(new_ans_params,
        #                       lr=lr_1,
        #                       warmup=0.0,
        #                       t_total=-1)
        # print('### Start finetuning new ans...')
        # self.train(train_tuple, eval_tuple)

        # First step, only updates answer head

        #self.optim = torch.optim.Adamax(list(self.model.parameters()), lr_1)
        #self.optim = torch.optim.SGD(list(self.model.parameters()), lr_1)
        args.epochs = epochs_1
        batch_per_epoch = len(self.train_tuple.loader)
        t_total = int(batch_per_epoch * epochs_1)
        print("Total Iters: %d" % t_total)
        from lxrt.optimization import BertAdam
        self.optim = BertAdam(
            list(self.model.parameters()),
            lr=lr_1,
            warmup=0.0,  #!0.034
            t_total=-1)  # t_total=-1 keeps the LR constant; the iteration count printed above is informational only
        # loaded_optim = torch.load("%s_LXRT.pth" % args.load_lxmert_qa)['optimizer']
        # self.optim.load_state_dict(loaded_optim)
        # for group in loaded_optim.param_groups:
        #     for p in group['params']:
        #         if p in loaded_optim['state']:
        #             self.optim.state[p] = loaded_optim.state[p]

        self.writerTbrd = SummaryWriter(boards_dir_1)
        self.output = output_1

        for name, p in self.model.named_parameters():
            if "logit_fc" in name:
                p.requires_grad = True
            else:
                p.requires_grad = False

        print('### Start finetuning step 1...')
        self.train(train_tuple, eval_tuple)

        # Second step, finetune all
        for name, p in self.model.named_parameters():
            p.requires_grad = True

        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * epochs_2)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=lr_2,
                                  warmup=0.1,
                                  t_total=t_total,
                                  lr_min=1e-7)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), lr_2)
        args.epochs = epochs_2
        self.writerTbrd = SummaryWriter(boards_dir_2)
        self.output = output_2

        print('### Start finetuning step 2...')
        self.train(train_tuple, eval_tuple)
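
# ---------------------------------------------------------------------------
# Hedged sketch (not part of the classes above): a minimal, self-contained
# illustration of the top-k masking that iou_preprocess() applies to the IoU
# targets before the KLDiv pointer loss. The [batch, n_words, n_objects]
# shape is inferred from the indexing in the training loop above.
import torch

def topk_iou_softmax(iou, topk=3):
    # k-th largest IoU along the object axis, one scalar per (batch, word)
    kth = torch.sort(iou, descending=True, dim=-1)[0][:, :, topk - 1]
    # push everything strictly below the k-th value to -1e9 before normalizing
    masked = iou.masked_fill(iou < kth.unsqueeze(-1), -1e9)
    return torch.softmax(masked, dim=-1)

if __name__ == "__main__":
    toy_iou = torch.rand(2, 4, 6)        # toy [batch, n_words, n_objects]
    dist = topk_iou_softmax(toy_iou)
    print(dist.sum(-1))                  # each word's distribution sums to 1
    print((dist > 1e-6).sum(-1))         # roughly topk nonzero entries per word
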
class GQA:
    def __init__(self):
        self.train_tuple = get_tuple(args.train,
                                     bs=args.batch_size,
                                     shuffle=True,
                                     drop_last=True)
        if args.valid != "":
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(args.valid,
                                         bs=valid_bsize,
                                         shuffle=False,
                                         drop_last=False,
                                         skip_semantics=True)
        else:
            self.valid_tuple = None

        self.model = GQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Losses and optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
            if args.task_nsp_qfpm or args.task_mlm_qfpm:
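                # a second BertAdam instance so the NSP/MLM pretraining pass
                # in train() runs on its own warmup schedule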
                self.task_optim = BertAdam(list(self.model.parameters()),
                                           lr=args.lr,
                                           warmup=0.1,
                                           t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        # Pre train the NSP head
        if args.task_nsp_qfpm or args.task_mlm_qfpm:
            print(f"********* Pretraining for {args.epochs} epochs *********")
            for epoch in range(args.epochs):
                epoch_pretrain_loss = 0
                epoch_nsp_loss = 0
                epoch_mlm_loss = 0
                epoch_nsp_avg = []
                for i, (ques_id, feats, boxes, sent, sem_query, sem_matched,
                        target) in iter_wrapper(enumerate(loader)):
                    self.model.train()
                    self.task_optim.zero_grad()

                    feats, boxes, sem_matched, target = (
                        feats.cuda(), boxes.cuda(), sem_matched.cuda(), target.cuda())
                    _, logit_nsp_qfpm, logit_mlm_qfpm, masked_lang_labels = self.model(
                        feats, boxes, sent, sem_query)

                    loss = 0
                    if args.task_nsp_qfpm:
                        nsp_qfpm_loss = self.mce_loss(logit_nsp_qfpm,
                                                      sem_matched)
                        loss += 100 * nsp_qfpm_loss  # scale NSP so it is weighted comparably to the MLM loss
                        epoch_nsp_loss += nsp_qfpm_loss.detach()
                        _, idx = torch.max(logit_nsp_qfpm, 1)
                        diff = torch.abs(sem_matched - idx)
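                        # abs(label - argmax) counts wrong NSP predictions,
                        # assuming binary matched/unmatched labels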
                        epoch_nsp_avg.append(
                            torch.sum(diff).item() / sem_matched.shape[0])

                    if args.task_mlm_qfpm:
                        # masked_lang_labels: [batch, max_length], logit_mlm_qfpm: [batch, max_length, vocab_size]
                        vocab_size = logit_mlm_qfpm.shape[2]
                        masked_lm_loss = self.mce_loss(
                            logit_mlm_qfpm.view(-1, vocab_size),
                            masked_lang_labels.view(-1))
                        loss += masked_lm_loss
                        epoch_mlm_loss += masked_lm_loss.detach()

                    loss.backward()
                    nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                    self.task_optim.step()
                    epoch_pretrain_loss += loss.detach()

                print(
                    f"Total loss for epoch = {epoch_pretrain_loss} ... NSP = {epoch_nsp_loss} ... MLM = {epoch_mlm_loss}"
                )
                if epoch_nsp_avg:
                    print(
                        f"NSP: mean per-batch mismatch rate = {sum(epoch_nsp_avg) / len(epoch_nsp_avg)}"
                    )

        best_valid = 0.
        args.task_nsp_qfpm = False
        args.task_mlm_qfpm = False
        if args.no_fp_train:
            loader.dataset.skip_semantics = True

        print(f"********* Finetuning for {args.epochs} epochs *********")
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, sem_query, _,
                    target) in iter_wrapper(enumerate(loader)):
                #                 sb = [q + ' ** ' + sq for q, sq in zip(sent, sem_query)]
                #                 print(sb[:10])
                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit_qa = self.model(feats, boxes, sent, sem_query)
                assert logit_qa.dim() == target.dim() == 2
                if args.mce_loss:
                    max_value, target = target.max(1)
                    loss = self.mce_loss(logit_qa, target) * logit_qa.size(1)
                else:
                    loss = self.bce_loss(logit_qa, target)
                    loss = loss * logit_qa.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit_qa.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent, sem_query = datum_tuple[:5]  # avoid handling target
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit_qa = self.model(feats, boxes, sent, sem_query)
                score, label = logit_qa.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        dset, loader, evaluator = eval_tuple
        quesid2ans = self.predict(eval_tuple, dump)
        return evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, sem_query, sem_matched,
                target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        for key in list(state_dict.keys()):
            if '.module' in key:
                state_dict[key.replace('.module', '')] = state_dict.pop(key)
        self.model.load_state_dict(state_dict, strict=False)
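
# ---------------------------------------------------------------------------
# Hedged sketch (illustration only): finetune() in the first VQA class trains
# in two stages, first updating only the answer head ("logit_fc"), then
# unfreezing everything at a lower learning rate. A minimal reproduction of
# that requires_grad pattern on a toy model (module names here are
# assumptions, not the real VQAModel):
import torch
import torch.nn as nn

def two_stage_finetune_sketch(lr=1e-4):
    model = nn.Sequential()
    model.add_module("encoder", nn.Linear(16, 16))
    model.add_module("logit_fc", nn.Linear(16, 4))

    # Stage 1: freeze everything except the answer head.
    for name, p in model.named_parameters():
        p.requires_grad = "logit_fc" in name
    head_optim = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=lr)

    # Stage 2: unfreeze all parameters and continue at lr / 10.
    for p in model.parameters():
        p.requires_grad = True
    full_optim = torch.optim.Adam(model.parameters(), lr=lr / 10)
    return model, head_optim, full_optim
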
class VQA:
    def __init__(self, attention=False):
        # Datasets
        print("Fetching data")
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True,
                                          dataset_name="test")
        print("Got data")
        print("fetching val data")
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=args.batch_size,
                                              shuffle=False,
                                              drop_last=False,
                                              dataset_name="test")
            print("got data")
        else:
            self.valid_tuple = None
        print("Got data")

        # Model
        print("Making model")
        self.model = VQAModel(self.train_tuple.dataset.num_answers, attention)
        print("Ready model")
        # Print model info:
        print("Num of answers:")
        print(self.train_tuple.dataset.num_answers)
        # print("Model info:")
        # print(self.model)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        log_freq = 810
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        flag = True
        for epoch in range(args.epochs):
            quesid2ans = {}
            correct = 0
            total_loss = 0
            total = 0
            print("Len of the dataloader: ", len(loader))
            #             Our new TGIFQA-Dataset returns:
            #             return gif_tensor, self.questions[i], self.ans2id[self.answer[i]]
            for i, (feats1, feats2, sent,
                    target) in iter_wrapper(enumerate(loader)):
                ques_id, boxes = -1, None
                self.model.train()
                self.optim.zero_grad()

                feats1, feats2, target = feats1.cuda(), feats2.cuda(), target.cuda()
                feats = [feats1, feats2]

                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                total_loss += loss.item()

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                score_t, target = target.max(1)
                correct += (label == target).sum().cpu().numpy()
                total += len(label)

                if i % log_freq == 1 and i > 1:
                    results = []
                    for l, s, t in zip(label, sent, target):
                        result = []
                        result.append(s)
                        result.append("Prediction: {}".format(
                            loader.dataset.label2ans[int(l.cpu().numpy())]))
                        result.append("Answer: {}".format(
                            loader.dataset.label2ans[int(t.cpu().numpy())]))
                        results.append(result)
                        torch.cuda.empty_cache()
                    val_loss, val_acc, val_results = self.val(eval_tuple)
                    logger.log(total_loss / total, correct / total * 100,
                               val_loss, val_acc, epoch, results, val_results)

            print("==" * 30)
            print("Accuracy = ", correct / total * 100)
            print("Loss =", total_loss / total)
            print("==" * 30)

            self.save("Check" + str(epoch))

    def val(self, eval_tuple):
        dset, loader, evaluator = eval_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
        self.model.eval()
        best_valid = 0.
        flag = True
        quesid2ans = {}
        correct = 0
        total_loss = 0
        total = 0
        results = []
        print("Len of the dataloader: ", len(loader))
        #             Our new TGIFQA-Dataset returns:
        #             return gif_tensor, self.questions[i], self.ans2id[self.answer[i]]
        with torch.no_grad():
            for i, (feats1, feats2, sent,
                    target) in iter_wrapper(enumerate(loader)):
                ques_id, boxes = -1, None

                feats1, feats2, target = feats1.cuda(), feats2.cuda(), target.cuda()
                feats = [feats1, feats2]

                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                total_loss += loss.item()

                score, label = logit.max(1)
                score_t, target = target.max(1)
                correct += (label == target).sum().cpu().numpy()
                total += len(label)
                for l, s, t in zip(label, sent, target):
                    result = []
                    result.append(s)
                    result.append("Prediction: {}".format(
                        loader.dataset.label2ans[int(l.cpu().numpy())]))
                    result.append("Answer: {}".format(
                        loader.dataset.label2ans[int(t.cpu().numpy())]))
                    results.append(result)
            return total_loss / total, correct / total * 100, results

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
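
# ---------------------------------------------------------------------------
# Hedged sketch (illustration only): every trainer above scales
# BCEWithLogitsLoss by logit.size(1). With the default mean reduction this is
# equivalent to summing the binary losses over the answer classes for each
# example and averaging over the batch, as this self-contained check shows.
import torch
import torch.nn as nn

def check_bce_scaling(batch=8, num_answers=3129):
    logit = torch.randn(batch, num_answers)
    target = torch.zeros(batch, num_answers)
    target.scatter_(1, torch.randint(0, num_answers, (batch, 1)), 1.0)

    # mean over batch*classes, rescaled by the class count ...
    scaled = nn.BCEWithLogitsLoss()(logit, target) * logit.size(1)
    # ... equals the per-example class-sum averaged over the batch
    per_example = nn.BCEWithLogitsLoss(reduction='none')(logit, target).sum(1)
    assert torch.allclose(scaled, per_example.mean(), atol=1e-5)
    return scaled

if __name__ == "__main__":
    print(check_bce_scaling())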