class VQA:
    def __init__(self, folder="/", load=True):
        # Datasets
        if load:
            self.train_tuple = get_data_tuple(
                args.train, bs=args.batch_size, shuffle=True, drop_last=True, folder=folder
            )
            if args.valid != "":
                self.valid_tuple = get_data_tuple(
                    args.valid, bs=128,
                    shuffle=False, drop_last=False, folder=folder
                )
            else:
                self.valid_tuple = None
        
        # Model
#         self.model = VQAModel(self.train_tuple.dataset.num_answers)
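        # 3129 answer classes: the size of the standard VQA v2 trainval answer vocabulary
        # used by LXMERT, hard-coded so the model can be built even when load=False.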
        self.model = VQAModel(3129)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)
        
        # GPU options
        self.model = self.model.cuda()

        # Indices assumed to be the positions of the "yes" and "no" answers in label2ans.
        self.yes_index = 425
        self.no_index = 1403

        # Mask over the answer logits that keeps only the yes/no entries, so they can be
        # renormalized into a two-way distribution (3129 = answer-vocabulary size).
        self.mask_yes = torch.zeros(3129).cuda()
        self.mask_yes[self.yes_index] = 1.0
        self.mask_yes[self.no_index] = 1.0

        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
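        # Assumed: constraint_loss below relies on self.mseloss, which this snippet never
        # defines; a plain MSE criterion is the natural reading.
        self.mseloss = nn.MSELoss()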
        if load:
            if 'bert' in args.optim:
                batch_per_epoch = len(self.train_tuple.loader)
                t_total = int(batch_per_epoch * args.epochs)
                print("BertAdam Total Iters: %d" % t_total)
                from lxrt.optimization import BertAdam
                self.optim = BertAdam(list(self.model.parameters()),
                                      lr=args.lr,
                                      warmup=0.1,
                                      t_total=t_total)
            else:
                self.optim = args.optimizer(self.model.parameters(), args.lr)
            # Output Directory
            self.output = args.output
            os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader), ascii=True)) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, ques, op, q1, q2, typetarget, q1typetarget, q2typetarget, yesnotypetargets, q1yntypetargets, q2yntypetargets, target, q1_target, q2_target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target, yntypetarget, typetarget = feats.cuda(), boxes.cuda(), target.cuda(), yesnotypetargets.cuda(), typetarget.cuda()
                op, q1typetarget, q2typetarget, q1yntypetargets, q2yntypetargets, q1_target, q2_target = op.cuda(), q1typetarget.cuda(), q2typetarget.cuda(), q1yntypetargets.cuda(), q2yntypetargets.cuda(), q1_target.cuda(), q2_target.cuda()

                logit = self.model(feats, boxes, ques)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)
                
                # Answer both sub-questions with the same model so their yes/no probabilities
                # can be checked for logical consistency with the composed question.
                q1logit = self.model(feats, boxes, q1)
                q2logit = self.model(feats, boxes, q2)

                constraint_loss = self.constraint_loss(logit, q1logit, q2logit, op)

                # Equal weighting of the supervised BCE term and the consistency term.
                loss = 0.5 * loss + 0.5 * constraint_loss

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")
        return best_valid
    
    def rangeloss(self, x, lower, upper, lamb=4):
        """Gaussian-well penalty: near 0 for x inside [lower, upper], rising toward 1 outside."""
        mean = (lower + upper) / 2
        sigma = (upper - lower + 0.00001) / lamb  # lamb controls how sharply the well closes
        loss = 1 - torch.exp(-0.5 * torch.pow(torch.div(x - mean, sigma), 2))
        return loss.sum()
    
    def select_yesnoprobs(self, logit, x, op):
        # Keep only the examples in the batch whose operator id equals x.
        op_mask = torch.eq(op, x)
        logit = logit[op_mask].view(-1, 3129)
        # Zero out everything except the "yes"/"no" logits, drop the zeros, and softmax
        # the remaining pair; column 0 is the "yes" probability (yes_index < no_index).
        # (This assumes the yes/no logits themselves are never exactly zero.)
        logit_m = logit * self.mask_yes
        m = logit_m == 0
        logit_m = logit_m[~m].view(-1, 2)
        logit_m = torch.softmax(logit_m, 1)
        return logit_m.select(dim=1, index=0).view(-1, 1)
        
    
    def constraint_loss(self, logit, q1_logit, q2_logit, op):
        total_loss = torch.zeros([1]).cuda()
        for x in range(1, 11):  # operator ids 1..10
            logit_m = self.select_yesnoprobs(logit, x, op)
            q1_logit_m = self.select_yesnoprobs(q1_logit, x, op)
            q2_logit_m = self.select_yesnoprobs(q2_logit, x, op)

            if logit_m.nelement() == 0:
                continue

            # Ideal yes-probability of the composed question, derived from the two
            # sub-question probabilities by the operator-specific rule in op_map.
            ideal_logit_m = op_map[x](q1_logit_m, q2_logit_m)
            pair_loss = self.mseloss(logit_m, ideal_logit_m)
            total_loss += pair_loss
        return total_loss
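
    # NOTE: op_map is assumed to be a module-level dict mapping each operator id (1..10)
    # to a function that combines the two sub-question yes-probabilities into the ideal
    # yes-probability of the composed question. A minimal illustrative sketch (hypothetical,
    # the real mapping is not shown in this snippet):
    #
    #     op_map = {
    #         1: lambda p1, p2: p1 * p2,            # e.g. logical AND
    #         2: lambda p1, p2: p1 + p2 - p1 * p2,  # e.g. logical OR
    #     }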
    

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in tqdm(enumerate(loader), ascii=True, desc="Evaluating"):
            # ques_id, feats, boxes, sent = datum_tuple[:4]   # Avoid seeing ground truth
            ques_id, feats, boxes, ques, op, q1, q2, typetarget, q1typetarget, q2typetarget, yesnotypetargets, q1yntypetargets, q2yntypetargets, target, q1_target, q2_target = datum_tuple

            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, ques)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, ques, op, q1, q2, typetarget, q1typetarget, q2typetarget, yesnotypetargets, q1yntypetargets, q2yntypetargets, target, q1_target, q2_target = datum_tuple
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)
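
    # save()/load() are called from train() above but are missing from this snippet;
    # presumably they match the identical versions in the examples below.
    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)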
Example #2
class VQA:
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=1024,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
Example #3
class VQA:
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=1024,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.

        This variant additionally saves lang2vis and vis2lang cross-modal attention
        visualizations for selected layers and heads while predicting.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple

        question_id2img_id = {x["question_id"]: x["img_id"] for x in dset.data}
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                                  do_lower_case=True)
        plt.rcParams['figure.figsize'] = (12, 10)
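        # 36 object regions: the fixed number of Faster R-CNN proposals in LXMERT's
        # precomputed image features.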
        num_regions = 36

        count = 0

        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)

                for layer in [0, 4]:
                    for head in [0, 1]:
                        for datapoint in range(len(sent)):
                            print(count, len(sent))
                            count += 1
                            lang2vis_attention_probs = self.model.lxrt_encoder.model.bert.encoder.x_layers[
                                layer].lang_att_map[datapoint][head].detach().cpu().numpy()

                            vis2lang_attention_probs = self.model.lxrt_encoder.model.bert.encoder.x_layers[
                                layer].visn_att_map[datapoint][head].detach().cpu().numpy()

                            plt.clf()

                            plt.subplot(2, 3, 1)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 0-7)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            plt.subplot(2, 3, 2)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 8-15)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            plt.subplot(2, 3, 3)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 16-35)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            img_info = loader.dataset.imgid2img[
                                question_id2img_id[ques_id[datapoint].item()]]
                            img_h, img_w = img_info['img_h'], img_info['img_w']
                            unnormalized_boxes = boxes[datapoint].clone()
                            unnormalized_boxes[:, (0, 2)] *= img_w
                            unnormalized_boxes[:, (1, 3)] *= img_h

                            for i, bbox in enumerate(unnormalized_boxes):
                                if i < 8:
                                    plt.subplot(2, 3, 1)
                                elif i < 16:
                                    plt.subplot(2, 3, 2)
                                else:
                                    plt.subplot(2, 3, 3)

                                bbox = [
                                    bbox[0].item(), bbox[1].item(),
                                    bbox[2].item(), bbox[3].item()
                                ]

                                if bbox[0] == 0:
                                    bbox[0] = 2
                                if bbox[1] == 0:
                                    bbox[1] = 2

                                plt.gca().add_patch(
                                    plt.Rectangle((bbox[0], bbox[1]),
                                                  bbox[2] - bbox[0] - 4,
                                                  bbox[3] - bbox[1] - 4,
                                                  fill=False,
                                                  edgecolor='red',
                                                  linewidth=1))

                                plt.gca().text(bbox[0],
                                               bbox[1] - 2,
                                               '%s' % i,
                                               bbox=dict(facecolor='blue'),
                                               fontsize=9,
                                               color='white')

                            ax = plt.subplot(2, 1, 2)
                            plt.title("Cross-modal attention lang2vis")

                            tokenized_question = tokenizer.tokenize(
                                sent[datapoint])
                            tokenized_question = [
                                "<CLS>"
                            ] + tokenized_question + ["<SEP>"]

                            transposed_attention_map = lang2vis_attention_probs[
                                :len(tokenized_question), :num_regions]
                            im = plt.imshow(transposed_attention_map,
                                            vmin=0,
                                            vmax=1)

                            for i in range(len(tokenized_question)):
                                for j in range(num_regions):
                                    att_value = round(
                                        transposed_attention_map[i, j], 1)
                                    text = ax.text(
                                        j,
                                        i,
                                        att_value,
                                        ha="center",
                                        va="center",
                                        color="w" if att_value <= 0.5 else "b",
                                        fontsize=6)

                            ax.set_xticks(np.arange(num_regions))
                            ax.set_xticklabels(list(range(num_regions)))

                            ax.set_yticks(np.arange(len(tokenized_question)))
                            ax.set_yticklabels(tokenized_question)

                            plt.tight_layout()
                            # plt.gca().set_axis_off()
                            plt.savefig(
                                "/mnt/8tera/claudio.greco/guesswhat_lxmert/guesswhat/visualization_vqa/lang2vis_question_{}_layer_{}_head_{}.png"
                                .format(ques_id[datapoint].item(), layer,
                                        head),
                                bbox_inches='tight',
                                pad_inches=0.5)

                            plt.close()

                            ## vis2lang

                            plt.clf()

                            plt.subplot(2, 3, 1)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 0-7)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            plt.subplot(2, 3, 2)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 8-15)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            plt.subplot(2, 3, 3)
                            plt.gca().set_axis_off()
                            plt.title("Image (regions 16-35)")
                            im = cv2.imread(
                                os.path.join(
                                    "/mnt/8tera/claudio.greco/mscoco_trainval_2014",
                                    question_id2img_id[
                                        ques_id[datapoint].item()]) + ".jpg")
                            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
                            plt.imshow(im)

                            img_info = loader.dataset.imgid2img[
                                question_id2img_id[ques_id[datapoint].item()]]
                            img_h, img_w = img_info['img_h'], img_info['img_w']
                            unnormalized_boxes = boxes[datapoint].clone()
                            unnormalized_boxes[:, (0, 2)] *= img_w
                            unnormalized_boxes[:, (1, 3)] *= img_h

                            for i, bbox in enumerate(unnormalized_boxes):
                                if i < 8:
                                    plt.subplot(2, 3, 1)
                                elif i < 16:
                                    plt.subplot(2, 3, 2)
                                else:
                                    plt.subplot(2, 3, 3)

                                bbox = [
                                    bbox[0].item(), bbox[1].item(),
                                    bbox[2].item(), bbox[3].item()
                                ]

                                if bbox[0] == 0:
                                    bbox[0] = 2
                                if bbox[1] == 0:
                                    bbox[1] = 2

                                plt.gca().add_patch(
                                    plt.Rectangle((bbox[0], bbox[1]),
                                                  bbox[2] - bbox[0] - 4,
                                                  bbox[3] - bbox[1] - 4,
                                                  fill=False,
                                                  edgecolor='red',
                                                  linewidth=1))

                                plt.gca().text(bbox[0],
                                               bbox[1] - 2,
                                               '%s' % i,
                                               bbox=dict(facecolor='blue'),
                                               fontsize=9,
                                               color='white')

                            ax = plt.subplot(2, 1, 2)
                            plt.title("Cross-modal attention vis2lang")

                            tokenized_question = tokenizer.tokenize(
                                sent[datapoint])
                            tokenized_question = [
                                "<CLS>"
                            ] + tokenized_question + ["<SEP>"]

                            transposed_attention_map = vis2lang_attention_probs.transpose()[
                                :len(tokenized_question), :num_regions]
                            im = plt.imshow(transposed_attention_map,
                                            vmin=0,
                                            vmax=1)

                            for i in range(len(tokenized_question)):
                                for j in range(num_regions):
                                    att_value = round(
                                        transposed_attention_map[i, j], 1)
                                    text = ax.text(
                                        j,
                                        i,
                                        att_value,
                                        ha="center",
                                        va="center",
                                        color="w" if att_value <= 0.5 else "b",
                                        fontsize=6)

                            ax.set_xticks(np.arange(num_regions))
                            ax.set_xticklabels(list(range(num_regions)))

                            ax.set_yticks(np.arange(len(tokenized_question)))
                            ax.set_yticklabels(tokenized_question)

                            plt.tight_layout()
                            # plt.gca().set_axis_off()
                            plt.savefig(
                                "/mnt/8tera/claudio.greco/guesswhat_lxmert/guesswhat/visualization_vqa/vis2lang_question_{}_layer_{}_head_{}.png"
                                .format(ques_id[datapoint].item(), layer,
                                        head),
                                bbox_inches='tight',
                                pad_inches=0.5)

                            plt.close()


                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
Example #4
class VQA:
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            valid_bsize = args.get("valid_batch_size", 16)
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=valid_bsize,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.get("load_lxmert_pretrain", None) is not None:
            load_lxmert_from_pretrain_noqa(args.load_lxmert_pretrain,
                                           self.model)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()
            self.model.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        best_valid = 0.
        train_results = []
        report_every = args.get("report_every", 100)
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, batch in iter_wrapper(enumerate(loader)):
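                # zip(*batch) implies the loader yields an uncollated list of examples
                # (a pass-through collate_fn), so each field arrives here as a tuple.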
                ques_id, feats, boxes, sent, tags, target = zip(*batch)
                self.model.train()
                self.optim.zero_grad()

                target = torch.stack(target).cuda()
                logit = self.model(feats, boxes, sent, tags)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()
                train_results.append(
                    pd.Series({"loss": loss.detach().mean().item()}))

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

                if i % report_every == 0 and i > 0:
                    print("Epoch: {}, Iter: {}/{}".format(
                        epoch, i, len(loader)))
                    print("    {}\n~~~~~~~~~~~~~~~~~~\n".format(
                        pd.DataFrame(train_results[-report_every:]).mean()))

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid and not args.get(
                        "special_test", False):
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)
            if epoch >= 5:
                self.save("Epoch{}".format(epoch))
            print(log_str, end='')
            print(args.output)

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, batch in enumerate(tqdm(loader)):
            fields = list(zip(*batch))
            ques_id, feats, boxes, sent, tags = fields[:5]  # target is deliberately not unpacked
            with torch.no_grad():
                #target = torch.stack(target).cuda()
                logit = self.model(feats, boxes, sent, tags)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
Example #5
class VQA:
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=1024,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers,
                              finetune_strategy=args.finetune_strategy)

        # if finetune strategy is spottune
        if args.finetune_strategy in PolicyStrategies:
            self.policy_model = PolicyLXRT(
                PolicyStrategies[args.finetune_strategy])

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.finetune_strategy in PolicyStrategies:
            self.policy_model = self.policy_model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()
            # The policy model only exists when a policy fine-tuning strategy is active.
            if args.finetune_strategy in PolicyStrategies:
                self.policy_model.policy_lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Optimizer for policy net
        if args.finetune_strategy in PolicyStrategies:
            self.policy_optim = args.policy_optimizer(
                self.policy_model.parameters(), args.policy_lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple, visualizer=None):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)

        wandb.watch(self.model, log='all')
        if args.finetune_strategy in PolicyStrategies:
            wandb.watch(self.policy_model, log='all')

        best_valid = 0.

        for epoch in range(args.epochs):
            # for policy vec plotting
            if args.finetune_strategy in PolicyStrategies:
                policy_save = torch.zeros(
                    PolicyStrategies[args.finetune_strategy] // 2).cpu()
                policy_max = 0

            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                if args.finetune_strategy in PolicyStrategies:
                    self.policy_model.train()
                    self.policy_optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()

                if args.finetune_strategy in PolicyStrategies:
                    # calculate the policy vector here
                    policy_vec = self.policy_model(feats, boxes, sent)
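                    # SpotTune-style routing (assumed): view the policy logits as
                    # (batch, n_blocks, 2), draw a near-discrete per-block decision with
                    # Gumbel-Softmax, and take channel 1 as the "fine-tune this block" gate.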
                    policy_action = gumbel_softmax(
                        policy_vec.view(policy_vec.size(0), -1, 2))
                    policy = policy_action[:, :, 1]
                    policy_save = policy_save + policy.clone().detach().cpu().sum(0)
                    policy_max += policy.size(0)
                    logit = self.model(feats, boxes, sent, policy)
                else:
                    logit = self.model(feats, boxes, sent)

                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()
                if args.finetune_strategy in PolicyStrategies:
                    self.policy_optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            # Plot the accumulated policy vector (only defined when a policy strategy is active).
            if visualizer is not None and args.finetune_strategy in PolicyStrategies:
                print(f'Creating training visualizations for epoch {epoch}')
                visualizer.plot(policy_save,
                                policy_max,
                                epoch=epoch,
                                mode='train')

            train_acc = evaluator.evaluate(quesid2ans) * 100.
            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, train_acc)

            wandb.log({'Training Accuracy': train_acc})

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple,
                                            epoch=epoch,
                                            visualizer=visualizer)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

                wandb.log({'Validation Accuracy': valid_score * 100.})

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self,
                eval_tuple: DataTuple,
                dump=None,
                epoch=0,
                visualizer=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        if args.finetune_strategy in PolicyStrategies:
            self.policy_model.eval()
            policy_save = torch.zeros(
                PolicyStrategies[args.finetune_strategy] // 2)
            policy_max = 0

        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                if args.finetune_strategy in PolicyStrategies:
                    # calculate the policy vector here
                    policy_vec = self.policy_model(feats, boxes, sent)
                    policy_action = gumbel_softmax(
                        policy_vec.view(policy_vec.size(0), -1, 2))
                    policy = policy_action[:, :, 1]
                    policy_save = policy_save + policy.clone().detach().cpu().sum(0)
                    policy_max += policy.size(0)
                    logit = self.model(feats, boxes, sent, policy)
                else:
                    logit = self.model(feats, boxes, sent)

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

        if visualizer is not None and args.finetune_strategy in PolicyStrategies:
            print(f'Creating validation visualization for epoch {epoch}...')
            visualizer.plot(policy_save, policy_max, epoch=epoch, mode='val')

        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self,
                 eval_tuple: DataTuple,
                 dump=None,
                 epoch=0,
                 visualizer=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple,
                                  dump,
                                  epoch=epoch,
                                  visualizer=visualizer)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
Example #6
class VQA:
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=1024,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self,
              train_tuple,
              eval_tuple,
              adversarial=False,
              adv_batch_prob=0.0,
              attack_name=None,
              attack_params={}):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) if args.tqdm else (lambda x: x)
        use_adv_batch = False

        best_valid = 0.

        for epoch in range(args.epochs):
            quesid2ans = {}
            # Count the number of batches that were adversarially perturbed
            n_adv_batches = 0
            for i, (ques_id, feats, boxes, sent, target) in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()

                # If doing adversarial training, perturb input features
                # with probability adv_batch_prob
                if adversarial:
                    rand = random.uniform(0, 1)
                    use_adv_batch = rand <= adv_batch_prob
                if use_adv_batch:
                    # Create adversary from given class name and parameters
                    n_adv_batches += 1
                    AdversaryClass_ = getattr(advertorch_module, attack_name)
                    adversary = AdversaryClass_(
                        lambda x: self.model(x, boxes, sent),
                        loss_fn=self.bce_loss,
                        **attack_params)
                    # Perturb feats using adversary
                    feats = adversary.perturb(feats, target)

                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans

            log_str = "\nEpoch %d: Train %0.2f\n" % (epoch, evaluator.evaluate(quesid2ans) * 100.) + \
                        "Epoch %d: Num adversarial batches %d / %d\n" % (epoch, n_adv_batches, i+1)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def adversarial_predict(self,
                            eval_tuple: DataTuple,
                            dump=None,
                            attack_name='GradientAttack',
                            attack_params={}):
        """
        Predict the answers to questions in a data split, but
        using a specified adversarial attack on the inputs.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        sim_trace = []  # Track avg cos similarity across batches
        for i, datum_tuple in enumerate(tqdm(loader)):
            ques_id, feats, boxes, sent, target = datum_tuple
            feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()

            # Create adversary from given class name and parameters
            AdversaryClass_ = getattr(advertorch_module, attack_name)
            adversary = AdversaryClass_(lambda x: self.model(x, boxes, sent),
                                        loss_fn=self.bce_loss,
                                        **attack_params)
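            # The lambda closes over this batch's boxes and sentences, so only the
            # visual features are perturbed; advertorch calls it as the forward pass.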

            # Perturb feats using adversary
            feats_adv = adversary.perturb(feats, target)

            # Compute average cosine similarity between true
            # and perturbed features
            sim_trace.append(self.avg_cosine_sim(feats, feats_adv))

            # Compute prediction on adversarial examples
            with torch.no_grad():
                feats_adv = feats_adv.cuda()
                logit = self.model(feats_adv, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        print(
            f"Average cosine similarity across batches: {torch.stack(sim_trace).mean().item():.4f}"
        )
        return quesid2ans
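
    # Example usage (a sketch, not part of the original script): `attack_name`
    # can be any attack class exposed by `advertorch_module`, e.g. FGSM via
    # advertorch's GradientSignAttack; the eps below is illustrative, and the
    # attack's default clip range of [0, 1] may not suit unnormalized ROI feats.
    #
    #     quesid2ans = vqa.adversarial_predict(
    #         vqa.valid_tuple,
    #         attack_name='GradientSignAttack',
    #         attack_params={'eps': 0.05})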

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    def adversarial_evaluate(self,
                             eval_tuple: DataTuple,
                             dump=None,
                             attack_name='GradientAttack',
                             attack_params={}):
        """Evaluate model on adversarial inputs"""
        quesid2ans = self.adversarial_predict(eval_tuple, dump, attack_name,
                                              attack_params)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    def avg_cosine_sim(self, feats: torch.Tensor, feats_adv: torch.Tensor):
        """Computes the average cosine similarity between true and adversarial examples"""
        return nn.functional.cosine_similarity(feats, feats_adv, dim=-1).mean()
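
    # avg_cosine_sim never touches `self`, so it acts as a static helper.
    # Illustrative check: identical inputs give similarity ~1.0, e.g.
    #     x = torch.randn(8, 36, 2048)  # hypothetical (batch, boxes, dim) shape
    #     assert torch.isclose(VQA.avg_cosine_sim(None, x, x), torch.tensor(1.0))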

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)
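
    # oracle_score reports the accuracy ceiling reached by always predicting
    # the highest-scored ground-truth label, e.g. (sketch):
    #     print("Oracle: %0.2f" % (VQA.oracle_score(vqa.valid_tuple) * 100.))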

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
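

# NOTE: the second VQA class below is a separate variant of the trainer,
# adapted for a video-QA (TGIF-QA style) dataset that yields two feature
# streams per example. Because it reuses the name `VQA`, it shadows the class
# defined above when both live in one module; rename or split files before use.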
class VQA:
    def __init__(self, attention=False):
        # Datasets
        print("Fetching data")
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True,
                                          dataset_name="test")
        print("Got data")
        print("fetching val data")
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=args.batch_size,
                                              shuffle=False,
                                              drop_last=False,
                                              dataset_name="test")
            print("got data")
        else:
            self.valid_tuple = None
        print("Got data")

        # Model
        print("Making model")
        self.model = VQAModel(self.train_tuple.dataset.num_answers, attention)
        print("Ready model")
        # Print model info:
        print("Num of answers:")
        print(self.train_tuple.dataset.num_answers)
        # print("Model info:")
        # print(self.model)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

    def train(self, train_tuple, eval_tuple):
        """Train on the two-stream (video) dataset, logging every `log_freq` batches."""
        log_freq = 810  # number of batches between intermediate validation/logging
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) \
            if args.tqdm else (lambda x: x)

        for epoch in range(args.epochs):
            quesid2ans = {}
            correct = 0
            total_loss = 0
            total = 0
            print("Len of the dataloader: ", len(loader))
            #             Our new TGIFQA-Dataset returns:
            #             return gif_tensor, self.questions[i], self.ans2id[self.answer[i]]
            for i, (feats1, feats2, sent,
                    target) in iter_wrapper(enumerate(loader)):
                ques_id, boxes = -1, None
                self.model.train()
                self.optim.zero_grad()

                feats1, feats2, target = feats1.cuda(), feats2.cuda(
                ), target.cuda()
                feats = [feats1, feats2]

                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                total_loss += loss.item()

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()

                score, label = logit.max(1)
                score_t, target = target.max(1)  # target now holds class indices
                correct += (label == target).sum().cpu().numpy()
                total += len(label)

                if i % log_freq == 1 and i > 1:
                    results = []
                    for l, s, t in zip(label, sent, target):
                        result = []
                        result.append(s)
                        result.append("Prediction: {}".format(
                            loader.dataset.label2ans[int(l.cpu().numpy())]))
                        result.append("Answer: {}".format(
                            loader.dataset.label2ans[int(t.cpu().numpy())]))
                        results.append(result)
                        torch.cuda.empty_cache()
                    val_loss, val_acc, val_results = self.val(eval_tuple)
                    logger.log(total_loss / total, correct / total * 100,
                               val_loss, val_acc, epoch, results, val_results)

            print("==" * 30)
            print("Accuracy = ", correct / total * 100)
            print("Loss =", total_loss / total)
            print("==" * 30)

            self.save("Check" + str(epoch))

    def val(self, eval_tuple):
        """Run validation; returns (avg loss per example, accuracy in %, per-example results)."""
        dset, loader, evaluator = eval_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))) \
            if args.tqdm else (lambda x: x)
        self.model.eval()
        correct = 0
        total_loss = 0
        total = 0
        results = []
        print("Length of the dataloader:", len(loader))
        # The TGIF-QA dataset returns:
        #     gif_tensor, self.questions[i], self.ans2id[self.answer[i]]
        with torch.no_grad():
            for i, (feats1, feats2, sent, target) in iter_wrapper(enumerate(loader)):
                ques_id, boxes = -1, None  # this dataset has no question ids or boxes

                feats1, feats2, target = feats1.cuda(), feats2.cuda(), target.cuda()
                feats = [feats1, feats2]

                logit = self.model(feats, boxes, sent)
                assert logit.dim() == target.dim() == 2
                loss = self.bce_loss(logit, target)
                loss = loss * logit.size(1)

                total_loss += loss.item()

                score, label = logit.max(1)
                score_t, target = target.max(1)  # target now holds class indices
                correct += (label == target).sum().cpu().numpy()
                total += len(label)
                for l, s, t in zip(label, sent, target):
                    result = []
                    result.append(s)
                    result.append("Prediction: {}".format(
                        loader.dataset.label2ans[int(l.cpu().numpy())]))
                    result.append("Answer: {}".format(
                        loader.dataset.label2ans[int(t.cpu().numpy())]))
                    results.append(result)
            return total_loss / total, correct / total * 100, results

    def predict(self, eval_tuple: DataTuple, dump=None):
        """
        Predict the answers to questions in a data split.

        :param eval_tuple: The data tuple to be evaluated.
        :param dump: The path of saved file to dump results.
        :return: A dict of question_id to answer.
        """
        self.model.eval()
        dset, loader, evaluator = eval_tuple
        quesid2ans = {}
        for i, datum_tuple in enumerate(loader):
            ques_id, feats, boxes, sent = datum_tuple[:4]  # Avoid seeing ground truth
            with torch.no_grad():
                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(feats, boxes, sent)
                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid.item()] = ans
        if dump is not None:
            evaluator.dump_result(quesid2ans, dump)
        return quesid2ans

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        quesid2ans = self.predict(eval_tuple, dump)
        return eval_tuple.evaluator.evaluate(quesid2ans)

    @staticmethod
    def oracle_score(data_tuple):
        dset, loader, evaluator = data_tuple
        quesid2ans = {}
        for i, (ques_id, feats, boxes, sent, target) in enumerate(loader):
            _, label = target.max(1)
            for qid, l in zip(ques_id, label.cpu().numpy()):
                ans = dset.label2ans[l]
                quesid2ans[qid.item()] = ans
        return evaluator.evaluate(quesid2ans)

    def save(self, name):
        torch.save(self.model.state_dict(),
                   os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)
        state_dict = torch.load("%s.pth" % path)
        self.model.load_state_dict(state_dict)
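

# Example usage of the video-QA variant (a sketch; assumes `args` has been
# parsed as elsewhere in this file and a CUDA device is available):
#     vqa = VQA(attention=True)
#     vqa.train(vqa.train_tuple, vqa.valid_tuple)
#     val_loss, val_acc, val_results = vqa.val(vqa.valid_tuple)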