Example No. 1
    def prepare_train_data(self):
        # Labelled subset: keep only a fraction of the training split, with labels.
        label_dset = NLVR2Dataset(splits=args.train,
                                  fraction=args.mixmatch_label_fraction)
        # Unlabelled subset: the complementary ("reverse") portion of the same
        # split, with labels stripped.
        unlabel_dset = NLVR2Dataset(splits=args.train,
                                    fraction=args.mixmatch_label_fraction,
                                    remove_labels=True,
                                    keep_reverse=True)
        print("Final Labelled / Unlabelled Split = %d / %d" %
              (len(label_dset), len(unlabel_dset)))
        # Both torch datasets share a single image-feature list.
        self.img_data = []
        label_tset = NLVR2TorchDataset(label_dset, img_data=self.img_data)
        unlabel_tset = NLVR2TorchDataset(unlabel_dset, img_data=self.img_data)

        # Labelled and unlabelled loaders each take half of args.batch_size,
        # so each step can draw equal-sized labelled and unlabelled batches.
        label_data_loader = DataLoader(label_tset,
                                       batch_size=args.batch_size // 2,
                                       shuffle=True,
                                       num_workers=args.num_workers,
                                       drop_last=True,
                                       pin_memory=True)
        unlabel_data_loader = DataLoader(unlabel_tset,
                                         batch_size=args.batch_size // 2,
                                         shuffle=True,
                                         num_workers=args.num_workers,
                                         drop_last=True,
                                         pin_memory=True)

        evaluator = NLVR2Evaluator(label_dset)

        self.train_tuple = MixMatchDataTuple(
            dataset=label_dset,
            loader=label_data_loader,
            unlabel_dataset=unlabel_dset,
            unlabel_loader=unlabel_data_loader,
            evaluator=evaluator)
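MixMatchDataTuple is not defined in this example; a plausible sketch, assuming it follows the same namedtuple pattern as the DataTuple used in the other examples and takes exactly the fields passed above:

from collections import namedtuple

# Assumed definition: field names are taken from the keyword arguments above.
MixMatchDataTuple = namedtuple(
    "MixMatchDataTuple",
    ["dataset", "loader", "unlabel_dataset", "unlabel_loader", "evaluator"])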
Example No. 2
def get_tuple(splits: str, bs: int, shuffle=False, drop_last=False) -> DataTuple:
    dset = NLVR2Dataset(splits)
    tset = NLVR2TorchDataset(dset)
    evaluator = NLVR2Evaluator(dset)
    data_loader = DataLoader(
        tset, batch_size=bs,
        shuffle=shuffle, num_workers=args.num_workers,
        drop_last=drop_last, pin_memory=True
    )

    return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
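A minimal usage sketch for this helper, assuming args provides a batch size as in the other examples; the split names below are placeholders:

# Hypothetical call sites; 'train' and 'valid' are placeholder split names.
train_tuple = get_tuple('train', bs=args.batch_size, shuffle=True, drop_last=True)
valid_tuple = get_tuple('valid', bs=args.batch_size, shuffle=False, drop_last=False)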
Example No. 3
    def reinitialize(self):
        # Rebuild train_tuple from the dataset updated for the next self-training round
        dset = self.updated_dataset_for_self_train()
        tset = NLVR2TorchDataset(dset, img_data=self.img_data)
        evaluator = NLVR2Evaluator(dset)
        data_loader = DataLoader(tset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 drop_last=True,
                                 pin_memory=True)
        self.train_tuple = DataTuple(dataset=dset,
                                     loader=data_loader,
                                     evaluator=evaluator)

        # Reset the model and optimizers from scratch before the next round.
        del self.model
        self.setup_model()
        self.setup_optimizers()
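How reinitialize is driven across self-training rounds is not shown here; a rough sketch, where train_one_round and pseudo_label_unlabelled are hypothetical method names standing in for the actual training and pseudo-labelling steps:

    def self_train(self, num_rounds):
        # Hypothetical driver loop; only reinitialize() comes from the example above.
        for _ in range(num_rounds):
            self.train_one_round(self.train_tuple)
            self.pseudo_label_unlabelled()
            self.reinitialize()  # fresh dataset, model, and optimizers each round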
Example No. 4
def get_tuple(splits: str,
              bs: int,
              shuffle=False,
              drop_last=False,
              fraction=None,
              img_data=None) -> DataTuple:
    dset = NLVR2Dataset(splits, fraction=fraction)
    tset = NLVR2TorchDataset(dset, img_data=img_data)
    evaluator = NLVR2Evaluator(dset)
    data_loader = DataLoader(tset,
                             batch_size=bs,
                             shuffle=shuffle,
                             num_workers=args.num_workers,
                             drop_last=drop_last,
                             pin_memory=True)
    print("Final length of splits %s = %d" % (splits, len(tset)))

    return DataTuple(dataset=dset, loader=data_loader, evaluator=evaluator)
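The fraction and img_data parameters let callers subsample a split and share one image-feature list across tuples; a sketch of such calls, where args.valid is an assumed attribute alongside args.train:

# Sketch: share a single image-feature cache between two tuples.
img_data = []
train_tuple = get_tuple(args.train, bs=args.batch_size, shuffle=True, drop_last=True,
                        fraction=args.mixmatch_label_fraction, img_data=img_data)
valid_tuple = get_tuple(args.valid, bs=args.batch_size, shuffle=False, drop_last=False,
                        img_data=img_data)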