def __init__(self):
    """Set up data pipelines, the VQA model, optional pretrained weights, and the output dir."""
    # Training data is always required; shuffle + drop_last for stable batch shapes.
    self.train_tuple = get_data_tuple(
        args.train, bs=args.batch_size, shuffle=True, drop_last=True,
        dataset_type=args.dataset_type)

    # Validation is optional — a large fixed batch size, no shuffling.
    self.valid_tuple = (
        get_data_tuple(args.valid, bs=128, shuffle=False, drop_last=False,
                       dataset_type=args.dataset_type)
        if args.valid != "" else None
    )

    # When transfer-learning, the answer head is sized from the full answer
    # vocabulary rather than from the current training split.
    num_answers = (VQADataset.get_answers_number()
                   if args.transfer_learning
                   else self.train_tuple.dataset.num_answers)
    self.model = VQAModel(num_answers, encoder_type=args.encoder_type)

    # Optionally restore pretrained LXMERT weights (encoder-only, then QA-head variant).
    if args.load_lxmert is not None:
        self.model.lxrt_encoder.load(args.load_lxmert)
    if args.load_lxmert_qa is not None:
        load_lxmert_qa(args.load_lxmert_qa, self.model,
                       label2ans=self.train_tuple.dataset.label2ans)

    self.prepare_model()

    # Ensure the output directory exists before anything is written to it.
    self.output = args.output
    os.makedirs(self.output, exist_ok=True)
def get_data_tuple(splits: str, bs: int, shuffle: bool = False, drop_last: bool = False,
                   dataset_type: str | None = None) -> DataTuple:
    """Build the (dataset, loader, evaluator) triple for the given VQA splits.

    Args:
        splits: Comma-separated split names (e.g. "train,nominival").
        bs: Batch size for the DataLoader.
        shuffle: Whether to shuffle batches each epoch.
        drop_last: Whether to drop the final incomplete batch.
        dataset_type: Optional dataset variant selector forwarded to VQADataset.
            Callers in __init__ pass args.dataset_type; previously this keyword
            was not accepted here, so every call raised TypeError.

    Returns:
        DataTuple(dataset, loader, evaluator).
    """
    # NOTE(review): assumes VQADataset accepts a `dataset_type` keyword — confirm
    # against its definition. When None, fall back to the original call form so
    # pre-existing behavior is unchanged.
    if dataset_type is None:
        dset = VQADataset(splits)
    else:
        dset = VQADataset(splits, dataset_type=dataset_type)
    tset = VQATorchDataset(dset)
    evaluator = VQAEvaluator(dset)
    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=args.num_workers,
        drop_last=drop_last, pin_memory=True  # pin_memory speeds host->GPU transfer
    )
    return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)