Example #1
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True,
                                          dataset_type=args.dataset_type)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=128,
                                              shuffle=False,
                                              drop_last=False,
                                              dataset_type=args.dataset_type)
        else:
            self.valid_tuple = None
        # Model: with transfer learning, size the answer head from the full
        # VQA answer vocabulary instead of the current dataset's
        self.model = VQAModel(
            self.train_tuple.dataset.num_answers
            if not args.transfer_learning else VQADataset.get_answers_number(),
            encoder_type=args.encoder_type)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)
        self.prepare_model()
        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)
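All of these trainers lean on the same pattern as the original LXMERT codebase, where get_data_tuple bundles a dataset, its DataLoader, and an evaluator into a named tuple. A minimal sketch of that pattern, assuming the LXMERT conventions (make_data_tuple and its parameters are illustrative, not any fork's exact signature):

import collections

from torch.utils.data import DataLoader, Dataset

# LXMERT bundles these three pieces so train/eval code can pass one object
DataTuple = collections.namedtuple("DataTuple", "dataset loader evaluator")

def make_data_tuple(dataset: Dataset, evaluator, bs: int,
                    shuffle: bool = False, drop_last: bool = False) -> DataTuple:
    loader = DataLoader(dataset, batch_size=bs, shuffle=shuffle,
                        drop_last=drop_last, num_workers=0, pin_memory=True)
    return DataTuple(dataset=dataset, loader=loader, evaluator=evaluator)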
Example #2
    def __init__(self, folder="/", load=True):
        # Datasets
        if load:
            self.train_tuple = get_data_tuple(args.train,
                                              bs=args.batch_size,
                                              shuffle=True,
                                              drop_last=True,
                                              folder=folder)
            if args.valid != "":
                self.valid_tuple = get_data_tuple(args.valid,
                                                  bs=128,
                                                  shuffle=False,
                                                  drop_last=False,
                                                  folder=folder,
                                                  nops=args.nops)
            else:
                self.valid_tuple = None

        # Model
        # The answer-vocabulary size is hard-coded instead of read from
        # self.train_tuple.dataset.num_answers (the train tuple may not
        # exist when load=False): 3,129 answers for VQA v2, 16,039 for
        # VQA-CP v2.
        is_cp = "vqacpv2" in folder
        if not is_cp:
            self.model = VQAModel(3129)
        else:
            self.model = VQAModel(16039)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if load:
            if 'bert' in args.optim:
                batch_per_epoch = len(self.train_tuple.loader)
                t_total = int(batch_per_epoch * args.epochs)
                print("BertAdam Total Iters: %d" % t_total)
                from lxrt.optimization import BertAdam
                self.optim = BertAdam(list(self.model.parameters()),
                                      lr=args.lr,
                                      warmup=0.1,
                                      t_total=t_total)
            else:
                self.optim = args.optimizer(self.model.parameters(), args.lr)
            # Output Directory
            self.output = args.output
            os.makedirs(self.output, exist_ok=True)
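Example #2 hard-codes the answer-vocabulary sizes. A hedged alternative is to read the size from the label2ans mapping shipped with each dataset; the sketch below assumes the LXMERT data layout, and the trainval_label2ans.json filename is an assumption, not something shown in the snippet:

import json
import os

def num_answers_for(folder: str, fname: str = "trainval_label2ans.json") -> int:
    # Answer-vocabulary size = number of entries in the label2ans mapping
    with open(os.path.join(folder, fname)) as f:
        return len(json.load(f))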
Example #3
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.bert_type == 'ft':
            bs_infer = 256
        else:
            bs_infer = 1024
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=bs_infer,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        print("args.lr is {0}".format(args.lr))

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total,
                                  schedule=args.lr_schedule,
                                  args=args)

        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)
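BertAdam's warmup=0.1 and t_total arguments correspond to the 'warmup_linear' schedule from pytorch-pretrained-bert: the learning rate ramps linearly over the first 10% of steps, then decays linearly to zero. A minimal sketch of that schedule (the helper name is illustrative):

def warmup_linear_lr(step: int, t_total: int, base_lr: float,
                     warmup: float = 0.1) -> float:
    # LR at a given step under BertAdam's default 'warmup_linear' schedule
    x = step / max(1, t_total)
    if x < warmup:
        return base_lr * x / warmup      # linear ramp-up
    return base_lr * max(0.0, 1.0 - x)   # linear decay to zero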
Example #4
    def __init__(self):
        self.train_tuple = get_tuple(args.train,
                                     bs=args.batch_size,
                                     shuffle=True,
                                     drop_last=True)
        if args.valid != "":
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(args.valid,
                                         bs=valid_bsize,
                                         shuffle=False,
                                         drop_last=False)
        else:
            self.valid_tuple = None

        self.model = GQAModel(self.train_tuple.dataset.num_answers)
        # Momentum coefficient for the siamese (EMA) copy of the model
        self.momentum = 0.99997
        self.siam_model = copy.deepcopy(self.model)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
            self.siam_model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)
            load_lxmert_qa(args.load_lxmert_qa,
                           self.siam_model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        self.siam_model = self.siam_model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()
            self.siam_model.lxrt_encoder.multi_gpu()

        # Losses and optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)
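Example #4 keeps a deep-copied siam_model and a momentum of 0.99997, which suggests an exponential-moving-average (teacher) update. The update rule itself is not shown in the snippet, so this is only a sketch of the usual form:

import torch

@torch.no_grad()
def momentum_update(model: torch.nn.Module, siam_model: torch.nn.Module,
                    m: float = 0.99997) -> None:
    # siam <- m * siam + (1 - m) * online, parameter by parameter
    for p, p_s in zip(model.parameters(), siam_model.parameters()):
        p_s.data.mul_(m).add_(p.data, alpha=1.0 - m)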
Example #5
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            valid_bsize = args.get("valid_batch_size", 16)
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=valid_bsize,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.get("load_lxmert_pretrain", None) is not None:
            load_lxmert_from_pretrain_noqa(args.load_lxmert_pretrain,
                                           self.model)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()
            self.model.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)
Example #6
    def __init__(self, folder="/", load=True):
        # Datasets
        if load:
            self.train_tuple = get_data_tuple(args.train,
                                              bs=args.batch_size,
                                              shuffle=True,
                                              drop_last=True,
                                              folder=folder)
            if args.valid != "":
                self.valid_tuple = get_data_tuple(args.valid,
                                                  bs=128,
                                                  shuffle=False,
                                                  drop_last=False,
                                                  folder=folder,
                                                  nops=args.nops)
            else:
                self.valid_tuple = None

            get_bias(self.train_tuple.dataset, self.valid_tuple.dataset)
        # Model: sized from the MUTANT label2ans vocabulary rather than
        # self.train_tuple.dataset.num_answers
        label2ans = json.load(
            open("/data/datasets/vqa_mutant/data/vqa/mutant_l2a/mutant_label2ans.json"))
        self.model = VQAModel(len(label2ans))

        # Learned-Mixin debiasing head; w weighs the entropy penalty
        self.debias = LearnedMixin(w=0.36, hid_dim=self.model.lxrt_encoder.dim)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa, self.model,
                           label2ans=self.train_tuple.dataset.label2ans)
        
        # GPU options
        self.model = self.model.cuda()
        self.debias = self.debias.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if load:
            if 'bert' in args.optim:
                batch_per_epoch = len(self.train_tuple.loader)
                t_total = int(batch_per_epoch * args.epochs)
                print("BertAdam Total Iters: %d" % t_total)
                from lxrt.optimization import BertAdam
                self.optim = BertAdam(list(self.model.parameters()),
                                      lr=args.lr,
                                      warmup=0.1,
                                      t_total=t_total)
            else:
                self.optim = args.optimizer(self.model.parameters(), args.lr)
            # Output Directory
            self.output = args.output
            os.makedirs(self.output, exist_ok=True)
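The LearnedMixin head in Example #6 follows the Learned-Mixin+H debiasing scheme of Clark et al. (2019): the main logits are mixed with bias log-probabilities through a learned gate, and w (0.36 above) weighs an entropy penalty on the gate-scaled bias distribution. A simplified single-label sketch of that loss; the real module also learns the gate from a hidden state of size hid_dim:

import torch
import torch.nn.functional as F

def learned_mixin_loss(logits: torch.Tensor, bias: torch.Tensor,
                       labels: torch.Tensor, gate: torch.Tensor,
                       w: float = 0.36) -> torch.Tensor:
    # gate: non-negative per-example scalar g(x), shape (batch, 1)
    bias_lp = F.log_softmax(bias, dim=1)
    main = F.cross_entropy(logits + gate * bias_lp, labels)   # mix in log space
    # Entropy penalty on the gate-scaled bias distribution
    scaled = gate * bias_lp
    entropy = -(F.softmax(scaled, dim=1) *
                F.log_softmax(scaled, dim=1)).sum(dim=1).mean()
    return main + w * entropy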
Example #7
    def load_model(self):
        """
        Load the pre-trained VQA model
        """

        print(args.n_head)
        # update data, to allow multiple load_model() calls
        self.data_loader = Demo_data(self.cfg)

        # load answer dict
        with open(self.cfg['answers_dict'], 'r') as f:
            self.label_to_ans = json.load(f)

        # load architecture
        self.model = GQAModel(self.cfg['num_answers'])

        # load pretrained weights
        if self.cfg['ecai_lxmert']:
            if self.cfg['tiny_lxmert']:
                print(
                    "Sorry, there is no tiny version of ecai lxmert. Change config in src/task/demo_cfg.json!"
                )
                exit(0)
            # Load raw pretrained model
            path = self.cfg['pretrained_model_ecai']
            _ = load_lxmert_qa(path, self.model, label2ans=self.label_to_ans)

        else:
            print(self.cfg)
            # Load finetuned model
            if self.cfg['tiny_lxmert']:
                if self.cfg['oracle']:
                    path = self.cfg['pretrained_model_tiny_lxmert_oracle']
                else:
                    path = self.cfg['pretrained_model_tiny_lxmert']
            else:
                if self.cfg['oracle']:
                    print('Oracle model is only available in tiny version!')
                    exit(0)
                else:
                    path = self.cfg['pretrained_model_lxmert']
            print("Load model's weights from %s" % path)
            state_dict = torch.load("%s.pth" % path,
                                    map_location=torch.device('cpu'))
            # Strip the '.module' prefix left by nn.DataParallel checkpoints
            for key in list(state_dict.keys()):
                if '.module' in key:
                    state_dict[key.replace('.module', '')] = state_dict.pop(key)
            self.model.load_state_dict(state_dict, strict=False)

        # To GPU
        # self.model = self.model.cuda()

        print("Model loaded!")
Example #8
    def __init__(self, folder="/", load=True):
        # Datasets
        if load:
            self.train_tuple = get_data_tuple(args.train,
                                              bs=args.batch_size,
                                              shuffle=True,
                                              drop_last=True,
                                              folder=folder)
            if args.valid != "":
                self.valid_tuple = get_data_tuple(args.valid,
                                                  bs=512,
                                                  shuffle=False,
                                                  drop_last=False,
                                                  folder=folder)
            else:
                self.valid_tuple = None

        # Model
#         self.model = VQAModel(self.train_tuple.dataset.num_answers)
        self.model = VQAModel(len(self.train_tuple.dataset.label2ans),
                              fn_type=args.fn_type)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Load IndexList of Answer to Type Map
        self.indexlist = json.load(
            open(
                "/data/datasets/vqa_mutant/data/vqa/mutant_l2a/mutant_merge_indexlist.json"
            ))

        print("Length of Masks", len(self.indexlist), flush=True)

        # Binary masks selecting the answers of each of the four answer types
        indextensor = torch.cuda.LongTensor(self.indexlist)
        self.mask0 = torch.eq(indextensor, 0).float()
        self.mask1 = torch.eq(indextensor, 1).float()
        self.mask2 = torch.eq(indextensor, 2).float()
        self.mask3 = torch.eq(indextensor, 3).float()

        self.mask_cache = {}

        # Loss and Optimizer

        # Explicit dim avoids PyTorch's implicit-dimension deprecation warning
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)

        self.bceloss = nn.BCELoss()
        self.nllloss = nn.NLLLoss()

        self.bce_loss = nn.BCEWithLogitsLoss()
        self.ce_loss = nn.CrossEntropyLoss()

        if load:
            if 'bert' in args.optim:
                batch_per_epoch = len(self.train_tuple.loader)
                t_total = int(batch_per_epoch * args.epochs)
                print("BertAdam Total Iters: %d" % t_total)
                from lxrt.optimization import BertAdam
                self.optim = BertAdam(list(self.model.parameters()),
                                      lr=args.lr,
                                      warmup=0.1,
                                      t_total=t_total)
            else:
                self.optim = args.optimizer(self.model.parameters(), args.lr)
            # Output Directory
            self.output = args.output
            os.makedirs(self.output, exist_ok=True)
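Example #8 builds four binary masks from an answer-to-type index list but never shows how they are applied; presumably they restrict the logits to answers of the question's type before scoring. A hedged sketch of that use (the helper is an assumption, not the fork's code):

import torch

def mask_logits(logits: torch.Tensor, type_mask: torch.Tensor) -> torch.Tensor:
    # type_mask: 1.0 for answers of the desired type, 0.0 otherwise
    return logits.masked_fill(type_mask == 0, float('-inf'))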
Example #9
    def __init__(self, folder="/", load=True):
        # Datasets
        if load:
            self.train_tuple = get_data_tuple(args.train,
                                              bs=args.batch_size,
                                              shuffle=True,
                                              drop_last=True,
                                              folder=folder)
            if args.valid != "":
                self.valid_tuple = get_data_tuple(args.valid,
                                                  bs=128,
                                                  shuffle=False,
                                                  drop_last=False,
                                                  folder=folder,
                                                  nops=args.nops)
            else:
                self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers)
        #         is_cp=False
        #         if "vqacpv2" in folder:
        #             is_cp=True
        #         if not is_cp:
        #             self.model = VQAModel(3129)
        #         else:
        #             self.model = VQAModel(16039)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Pre-computed answer embeddings; the epsilon avoids zero-norm rows
        # before L2 normalization
        ans_embed = np.load(
            "/data/datasets/vqa_mutant/data/vqa/mutant_l2a/answer_embs.npy"
        ) + 1e-8
        ans_embed = torch.tensor(ans_embed).cuda()
        self.ans_embed = torch.nn.functional.normalize(ans_embed, dim=1)
        self.embed_cache = {}

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if load:
            if 'bert' in args.optim:
                batch_per_epoch = len(self.train_tuple.loader)
                t_total = int(batch_per_epoch * args.epochs)
                print("BertAdam Total Iters: %d" % t_total)
                from lxrt.optimization import BertAdam
                self.optim = BertAdam(list(self.model.parameters()),
                                      lr=args.lr,
                                      warmup=0.1,
                                      t_total=t_total)
            else:
                self.optim = args.optimizer(self.model.parameters(), args.lr)
            # Output Directory
            self.output = args.output
            os.makedirs(self.output, exist_ok=True)

        self.cos = nn.CosineSimilarity()
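With self.ans_embed L2-normalized row-wise, scoring a predicted embedding against every answer reduces to a single matrix product of cosine similarities. Example #9 only builds the table, so the scoring step below is an assumed usage, not the snippet's own code:

import torch
import torch.nn.functional as F

def answer_scores(pred_embed: torch.Tensor, ans_embed: torch.Tensor) -> torch.Tensor:
    # Rows of ans_embed are already unit-norm; normalize predictions too
    pred = F.normalize(pred_embed, dim=1)
    return pred @ ans_embed.t()   # (batch, num_answers) cosine similarities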
Example #10
    def __init__(self,
                 args,
                 train_loader=None,
                 val_loader=None,
                 logger=None,
                 num_answers=0,
                 train=True):
        self.args = args
        self.max_text_length = args.max_text_length
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.num_answers = num_answers
        self.logger = logger

        # Model
        self.model = GQAModel.from_pretrained("bert-base-uncased",
                                              args=args,
                                              num_answers=self.num_answers)

        self.verbose = True
        if self.args.distributed:
            if self.args.gpu != 0:
                self.verbose = False

        # Load Checkpoint
        self.start_epoch = None
        if args.load is not None:
            path = args.load + '.pth'
            self.load(path, verbose=self.verbose)

        elif args.load_lxmert_qa is not None:
            path = args.load_lxmert_qa + '_LXRT.pth'
            load_lxmert_qa(
                args,
                path,
                self.model,
                label2ans=self.train_loader.dataset.raw_dataset.label2ans,
                verbose=self.verbose)

        # GPU Options
        print(f'Model Launching at GPU {self.args.gpu}')
        from time import time
        start = time()
        self.model.cuda(args.gpu)

        # Optimizer
        if train:
            self.optim, self.lr_scheduler = self.create_optimizer_and_scheduler()
            self.bce_loss = nn.BCEWithLogitsLoss()

        if args.multiGPU:
            assert args.distributed
            self.model = DDP(self.model,
                             device_ids=[args.gpu],
                             find_unused_parameters=True)

        if args.gpu == 0:
            print(f'It took {time() - start:.1f}s')

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)
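Example #10 assumes one process per GPU (args.gpu, args.distributed) with the model wrapped in DDP; the process-group setup itself is not shown. A minimal sketch using the common env:// rendezvous (an assumption about the launcher, not code from the fork):

import torch
import torch.distributed as dist

def setup_distributed(gpu: int, world_size: int) -> None:
    # One process per GPU: pin the device, then join the NCCL group
    torch.cuda.set_device(gpu)
    dist.init_process_group(backend='nccl', init_method='env://',
                            rank=gpu, world_size=world_size)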
Example #11
    def __init__(self):
        self.train_tuple = get_tuple(args.train,
                                     bs=args.batch_size,
                                     shuffle=True,
                                     drop_last=True)
        if args.valid != "":
            valid_bsize = 2048 if args.multiGPU else 512
            self.valid_tuple = get_tuple(args.valid,
                                         bs=valid_bsize,
                                         shuffle=False,
                                         drop_last=False)
        else:
            self.valid_tuple = None

        self.model = GQAModel(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            self.new_ans_label = load_lxmert_qa(
                args.load_lxmert_qa,
                self.model,
                label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Losses and optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.mce_loss = nn.CrossEntropyLoss(ignore_index=-1)
        # self.KL_loss = nn.KLDivLoss(reduction='none')
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(list(self.model.parameters()), args.lr)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

        # Tensorboard
        self.boards_dir = os.path.join('boards', self.output)
        if not os.path.exists(self.boards_dir):
            os.makedirs(self.boards_dir)
        self.writerTbrd = SummaryWriter(self.boards_dir)

        # Get GloVe projections for all answers
        if args.answer_loss == 'glove':
            path_glove = './data/GloVe/GloVeDict.pkl'
            with open(path_glove, 'rb') as f:
                glove_dic = pickle.load(f)
            glove_dim = glove_dic['the'].shape[-1]
            print("Loading GloVe-%d answer vectors" % glove_dim)
            self.labelans2glove = []
            # 1 if every word of the answer has a GloVe vector, else 0
            self.valid_ans_embed = [1] * len(
                self.train_tuple.dataset.label2ans)
            for label, ans in enumerate(self.train_tuple.dataset.label2ans):
                ans = ans.split(' ')
                glove_ans = []
                for w in ans:
                    try:
                        glove_ans.append(glove_dic[w])
                    except KeyError:
                        # OOV word: flag the answer and use a zero vector
                        self.valid_ans_embed[label] = 0
                        glove_ans.append(np.zeros(glove_dim))
                # Average word vectors into a single embedding per answer
                glove_ans = torch.tensor(glove_ans).mean(-2)
                self.labelans2glove.append(glove_ans)
            print('Ratio of valid ans embedding: %f' %
                  (float(sum(self.valid_ans_embed)) / len(self.valid_ans_embed)))
            self.labelans2glove = torch.stack(
                self.labelans2glove).float().cuda()
            self.cosineSim = torch.nn.CosineSimilarity(dim=1, eps=1e-08)
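labelans2glove ends up as a (num_answers, glove_dim) matrix and cosineSim uses dim=1, which supports scoring a batch of predicted vectors against every answer at once via broadcasting. A sketch of that use (assumed; the 'glove' answer loss itself is not shown in the snippet):

import torch

def glove_answer_sims(pred_vec: torch.Tensor,
                      labelans2glove: torch.Tensor) -> torch.Tensor:
    # pred_vec: (batch, dim); labelans2glove: (num_answers, dim)
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-8)
    # (batch, dim, 1) vs (1, dim, num_answers) -> (batch, num_answers)
    return cos(pred_vec.unsqueeze(2), labelans2glove.t().unsqueeze(0))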
Example #12
    def __init__(self):
        # Datasets
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True)
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=1024,
                                              shuffle=False,
                                              drop_last=False)
        else:
            self.valid_tuple = None

        # Model
        self.model = VQAModel(self.train_tuple.dataset.num_answers,
                              finetune_strategy=args.finetune_strategy)

        # if finetune strategy is spottune
        if args.finetune_strategy in PolicyStrategies:
            self.policy_model = PolicyLXRT(
                PolicyStrategies[args.finetune_strategy])

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.finetune_strategy in PolicyStrategies:
            self.policy_model = self.policy_model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()
            self.policy_model.policy_lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Optimizer for policy net
        if args.finetune_strategy in PolicyStrategies:
            self.policy_optim = args.policy_optimizer(
                self.policy_model.parameters(), args.policy_lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)
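The PolicyLXRT model and separate policy optimizer in Example #12 point to SpotTune-style routing (Guo et al., 2019): a policy network emits per-layer binary decisions, sampled with Gumbel-softmax so they stay differentiable, that send each example through either a frozen or a fine-tuned copy of a layer. A sketch of one routing step; the real wiring lives inside the model, so this is illustrative only:

import torch
import torch.nn.functional as F

def spottune_route(policy_logits: torch.Tensor,
                   frozen_out: torch.Tensor,
                   finetuned_out: torch.Tensor) -> torch.Tensor:
    # policy_logits: (batch, 2); column 1 = "use the fine-tuned layer"
    action = F.gumbel_softmax(policy_logits, hard=True)[:, 1]
    # Broadcast the per-example decision over the remaining dims
    a = action.view(-1, *([1] * (finetuned_out.dim() - 1)))
    return a * finetuned_out + (1.0 - a) * frozen_out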
Example #13
    def __init__(self, attention=False):
        # Datasets
        print("Fetching data")
        self.train_tuple = get_data_tuple(args.train,
                                          bs=args.batch_size,
                                          shuffle=True,
                                          drop_last=True,
                                          dataset_name="test")
        print("Got data")
        print("fetching val data")
        if args.valid != "":
            self.valid_tuple = get_data_tuple(args.valid,
                                              bs=args.batch_size,
                                              shuffle=False,
                                              drop_last=False,
                                              dataset_name="test")
            print("got data")
        else:
            self.valid_tuple = None
        print("Got data")

        # Model
        print("Making model")
        self.model = VQAModel(self.train_tuple.dataset.num_answers, attention)
        print("Model ready")
        print("Num of answers:")
        print(self.train_tuple.dataset.num_answers)

        # Load pre-trained weights
        if args.load_lxmert is not None:
            self.model.lxrt_encoder.load(args.load_lxmert)
        if args.load_lxmert_qa is not None:
            load_lxmert_qa(args.load_lxmert_qa,
                           self.model,
                           label2ans=self.train_tuple.dataset.label2ans)

        # GPU options
        self.model = self.model.cuda()
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        # Loss and Optimizer
        self.bce_loss = nn.BCEWithLogitsLoss()
        if 'bert' in args.optim:
            batch_per_epoch = len(self.train_tuple.loader)
            t_total = int(batch_per_epoch * args.epochs)
            print("BertAdam Total Iters: %d" % t_total)
            from lxrt.optimization import BertAdam
            self.optim = BertAdam(list(self.model.parameters()),
                                  lr=args.lr,
                                  warmup=0.1,
                                  t_total=t_total)
        else:
            self.optim = args.optimizer(self.model.parameters(), args.lr)

        # Output Directory
        self.output = args.output
        os.makedirs(self.output, exist_ok=True)
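Despite their differences, these trainers share the same core step: soft VQA scores as multi-label targets trained with BCEWithLogitsLoss, with the mean loss scaled back up by the answer count as in the original LXMERT code. A sketch of that step, assuming the usual (feats, boxes, sent, target) batch layout from the LXMERT loaders:

import torch.nn as nn

def vqa_train_step(model, optim, bce_loss, feats, boxes, sent, target):
    optim.zero_grad()
    logit = model(feats, boxes, sent)
    # Mean BCE times the number of answers, per the LXMERT convention
    loss = bce_loss(logit, target) * logit.size(1)
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 5.0)
    optim.step()
    return loss.item()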