Example #1
def get_train_val_dataloader(args):
    dataset_list = args.name
    assert isinstance(dataset_list, list)
    datasets = []
    val_datasets = []
    limit = args.limit
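    # Build one train/val dataset pair per source name, using the selected auxiliary task (jigsaw or rotation)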
    for dname in dataset_list:
        if args.type == 'jigsaw':
            img_transformer, tile_transformer = get_jig_train_transformers(args)
            train_dataset = JigsawDataset(dname, split='train', val_size=args.val_size,
                    img_transformer=img_transformer, tile_transformer=tile_transformer,
                    jig_classes=args.aux_classes, bias_whole_image=args.bias_whole_image)
            val_dataset = JigsawTestDataset(dname, split='val', val_size=args.val_size,
                img_transformer=get_val_transformer(args), jig_classes=args.aux_classes)
        elif args.type == 'rotate':
            img_transformer = get_rot_train_transformers(args)
            train_dataset = RotateDataset(dname, split='train', val_size=args.val_size,
                    img_transformer=img_transformer, rot_classes=args.aux_classes, bias_whole_image=args.bias_whole_image)
            val_dataset = RotateTestDataset(dname, split='val', val_size=args.val_size,
                img_transformer=get_val_transformer(args), rot_classes=args.aux_classes)

        if limit:
            train_dataset = Subset(train_dataset, limit)

        datasets.append(train_dataset)
        val_datasets.append(val_dataset)

    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)
    return loader, val_loader
Example #2
def get_train_dataloader(args, patches):
    dataset_list = args.source
    assert isinstance(dataset_list, list)
    train_datasets = []
    val_datasets = []
    img_transformer, tile_transformer = get_train_transformers(args)
    limit = args.limit_source
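    # Each source domain's train list is split into train and validation parts according to args.val_size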
    for dname in dataset_list:
        name_train, name_val, labels_train, labels_val = get_split_dataset_info(
            join(dirname(__file__), 'txt_lists','%s_train.txt' % dname),
            args.val_size
        )
        train_dataset = JigsawDataset(name_train, labels_train, patches=patches,
                                      img_transformer=img_transformer,
                                      tile_transformer=tile_transformer,
                                      jig_classes=args.jigsaw_n_classes,
                                      bias_whole_image=args.bias_whole_image)
        if limit:
            train_dataset = Subset(train_dataset, limit)
        train_datasets.append(train_dataset)

        # Validation set => held out from the train split
        val_datasets.append(
            JigsawTestDataset(name_val, labels_val, img_transformer=get_val_transformer(args),
                              patches=patches, jig_classes=args.jigsaw_n_classes))

    train_dataset = ConcatDataset(train_datasets)
    val_dataset = ConcatDataset(val_datasets)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)
    return train_loader, val_loader
def get_digital_train_dataloader(args, dname):
    images, labels = None, []
    if dname == 'mnist':
        images, labels = load_mnist_dataset(split='train', flip_p=args.flip_p)

    dataset_digit = AdvDataset(images, labels)

    train_size, val_size = 0, 0
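    # Cap the train split at args.limit_source when it is set; the remainder becomes validation,
    # which is then trimmed so it keeps roughly the args.val_size proportion of the training data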
    if args.limit_source and len(dataset_digit) > args.limit_source:
        train_size = args.limit_source
    else:
        train_size = int(len(dataset_digit) * (1 - args.val_size))

    val_size = len(dataset_digit) - train_size
    train_set, val_set = torch.utils.data.random_split(dataset_digit,
                                                       [train_size, val_size])
    if val_size > int(train_size * args.val_size / (1 - args.val_size)):
        val_set = Subset(val_set, int(train_size * args.val_size /
                                      (1 - args.val_size)))

    train_set, val_set = ConcatDataset([train_set]), ConcatDataset([val_set])
    loader = torch.utils.data.DataLoader(train_set,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True,
                                             drop_last=False)
    return loader, val_loader
Example #4
def get_train_dataloader(args, patches):
    dataset_list = args.source
    assert isinstance(dataset_list, list)

    datasets = []
    val_datasets = []
    img_transformer, tile_transformer = get_train_transformers(args)
    limit = args.limit_source
    for dname in dataset_list:
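        # Nexperia datasets are read from the shared data root; all other datasets use the bundled txt lists and kfold image folders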
        if dname in new_nex_datasets:
            index_root = data_root = '/import/home/share/from_Nexperia_April2021/%s' % dname
        else:
            index_root = join(dirname(__file__),'correct_txt_lists')
            data_root = join(dirname(__file__),'kfold')
        name_train, labels_train = _dataset_info(join(index_root, "%s_train.txt" % dname))
        name_val, labels_val = _dataset_info(join(index_root, "%s_val.txt" % dname))
        train_dataset = JigsawNewDataset(data_root, name_train, labels_train, patches=patches, img_transformer=img_transformer,
                            tile_transformer=tile_transformer, jig_classes=30, bias_whole_image=args.bias_whole_image)
        if limit:
            train_dataset = Subset(train_dataset, limit)
        datasets.append(train_dataset)
        val_datasets.append(JigsawTestNewDataset(data_root, name_val, labels_val,
                                                 img_transformer=get_val_transformer(args),
                                                 patches=patches, jig_classes=30))

    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)
    return loader, val_loader
Example #5
    def load_datasets(self, args, gpu, n_gpus):
        for (framework, language), dataset in self.child_datasets.items():
            dataset.load_dataset(args, gpu, n_gpus, framework, language)

        self.share_chars()
        self.share_vocabs(args)

        train_datasets = [
            self.child_datasets[self.id_to_framework[i]].train
            for i in range(len(self.child_datasets))
        ]
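        # Merge the per-(framework, language) training sets into a single shuffled loader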
        self.train = torch.utils.data.DataLoader(ConcatDataset(train_datasets),
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 num_workers=args.workers,
                                                 collate_fn=Collate(),
                                                 pin_memory=True,
                                                 drop_last=True)
        self.train_size = len(self.train.dataset)
        self.mean_label_length = sum(
            dataset.node_count
            for dataset in self.child_datasets.values()) / self.train_size

        val_datasets = [
            self.child_datasets[self.id_to_framework[i]].val
            for i in range(len(self.child_datasets))
        ]
        self.val = torch.utils.data.DataLoader(
            ConcatDataset(val_datasets),
            batch_size=1,
            shuffle=False,
            num_workers=args.workers,
            collate_fn=Collate(),
            pin_memory=True,
        )
        self.val_size = len(self.val.dataset)

        test_datasets = [
            self.child_datasets[self.id_to_framework[i]].test
            for i in range(len(self.child_datasets))
        ]
        self.test = torch.utils.data.DataLoader(
            ConcatDataset(test_datasets),
            batch_size=1,
            shuffle=False,
            num_workers=args.workers,
            collate_fn=Collate(),
            pin_memory=True,
        )
        self.test_size = len(self.test.dataset)

        if gpu == 0:
            batch = next(iter(self.train))
            print(f"\nBatch content: {Batch.to_str(batch)}\n")
            print(flush=True)
Example #6
def get_train_dataloader(args):

    dataset_list = args.source
    assert isinstance(dataset_list, list)

    datasets = []
    val_datasets = []
    img_transformer = get_train_transformers(args)
    val_transformer = get_val_transformer(args)

    for dname in dataset_list:
        name_train, name_val, labels_train, labels_val = get_split_dataset_info(
            join(dirname(__file__), 'txt_lists', dname + '.txt'),
            args.val_size)
        # Batch of cropped and shuffled images for training; mixing controlled by args.betaJigen
        train_dataset = Dataset(name_train,
                                labels_train,
                                args.path_dataset,
                                img_transformer=img_transformer,
                                betaJigen=args.betaJigen,
                                rotation=args.rotation,
                                oddOneOut=args.oddOneOut)
        datasets.append(train_dataset)

        val_dataset = TestDataset(name_val,
                                  labels_val,
                                  args.path_dataset,
                                  img_transformer=val_transformer,
                                  betaJigen=args.betaJigen,
                                  rotation=args.rotation,
                                  oddOneOut=args.oddOneOut)
        val_datasets.append(val_dataset)

    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)

    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True,
                                             drop_last=False)

    return loader, val_loader
Example #7
def get_train_dataloader(args, patches):
    dataset_list = args.source
    assert isinstance(dataset_list, list)
    datasets = []
    val_datasets = []
    img_transformer, tile_transformer = get_train_transformers(args)
    limit = args.limit_source
    for dname in dataset_list:
        # name_train, name_val, labels_train, labels_val = get_split_dataset_info(join(dirname(__file__), 'txt_lists', '%s_train.txt' % dname), args.val_size)
        name_train, labels_train = _dataset_info(
            join(dirname(__file__), 'correct_txt_lists',
                 '%s_train_kfold.txt' % dname))
        name_val, labels_val = _dataset_info(
            join(dirname(__file__), 'correct_txt_lists',
                 '%s_crossval_kfold.txt' % dname))

        train_dataset = JigsawNewDataset(
            name_train,
            labels_train,
            patches=patches,
            img_transformer=img_transformer,
            tile_transformer=tile_transformer,
            jig_classes=30,
            bias_whole_image=args.bias_whole_image)
        if limit:
            train_dataset = Subset(train_dataset, limit)
        datasets.append(train_dataset)
        val_datasets.append(
            JigsawTestNewDataset(name_val,
                                 labels_val,
                                 img_transformer=get_val_transformer(args),
                                 patches=patches,
                                 jig_classes=30))
    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True,
                                             drop_last=False)
    return loader, val_loader
Example #8
def get_val_dataloader(args, patches=False):
    if args.stylized:
        names, labels = _dataset_info(
            join(dirname(__file__), 'txt_lists', 'Stylized' + args.dataset,
                 "{}_target".format(args.target), '%s_test.txt' % args.target))
    else:
        names, labels = _dataset_info(
            join(dirname(__file__), 'txt_lists', 'Vanilla' + args.dataset,
                 '%s_test.txt' % args.target))

    img_tr = get_val_transformer(args)
    val_dataset = JigsawTestDataset(names,
                                    labels,
                                    patches=patches,
                                    img_transformer=img_tr,
                                    jig_classes=args.jigsaw_n_classes)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    dataset = ConcatDataset([val_dataset])
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=False)
    return loader
Example #9
def get_val_dataloader(args, patches=False):
    dname = args.target
    if dname in new_nex_datasets:
        index_root = data_root = '/import/home/share/from_Nexperia_April2021/%s' % dname
    else:
        index_root = join(dirname(__file__), 'correct_txt_lists')
        data_root = join(dirname(__file__), 'kfold')
    names, labels = _dataset_info(join(index_root, "%s_val.txt" % dname))
    img_tr = get_nex_val_transformer(args)
    val_dataset = JigsawTestNewDataset(data_root,
                                       names,
                                       labels,
                                       patches=patches,
                                       img_transformer=img_tr,
                                       jig_classes=30)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    dataset = ConcatDataset([val_dataset])
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=False)
    return loader
Example #10
def get_jigsaw_dataloader(args):

    # Only used for domain adaptation (DA)
    names, labels = _dataset_info(
        join(dirname(__file__), 'txt_lists', args.target + '.txt'))
    img_tr = get_train_transformers(args)

    train_dataset = Dataset(names,
                            labels,
                            args.path_dataset,
                            img_transformer=img_tr,
                            beta_scrambled=args.beta_scrambled,
                            beta_rotated=args.beta_rotated,
                            beta_odd=args.beta_odd,
                            rotation=False,
                            odd=False)
    dataset = ConcatDataset([train_dataset])
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=False)

    return loader
Example #11
def get_tgt_dataloader(args, patches=False):
    '''
    Load whole domain dataset
    '''
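    # The target domain's train, val, and test splits are all loaded and concatenated into a single loader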
    img_tr = get_nex_val_transformer(args)
    dname = args.target
    if dname in new_nex_datasets:
        index_root = data_root = '/import/home/share/from_Nexperia_April2021/%s' % dname
    else:
        index_root = join(dirname(__file__),'correct_txt_lists')
        data_root = join(dirname(__file__),'kfold')
    if args.downsample_target:
        name_train, labels_train = _dataset_info(join(index_root,"%s_train_down.txt" % dname))
        name_val, labels_val = _dataset_info(join(index_root, "%s_val_down.txt" % dname))
        name_test, labels_test = _dataset_info(join(index_root, "%s_test_down.txt" % dname))
    else:
        name_train, labels_train = _dataset_info(join(index_root,"%s_train.txt" % dname))
        name_val, labels_val = _dataset_info(join(index_root, "%s_val.txt" % dname))
        name_test, labels_test = _dataset_info(join(index_root, "%s_test.txt" % dname))

    tgt_train_dataset = JigsawTestNewDataset(data_root, name_train, labels_train, patches=patches, img_transformer=img_tr,jig_classes=30)
    tgt_val_dataset = JigsawTestNewDataset(data_root, name_val, labels_val, patches=patches, img_transformer=img_tr,
                                            jig_classes=30)
    tgt_test_dataset = JigsawTestNewDataset(data_root, name_test, labels_test, patches=patches, img_transformer=img_tr,
                                            jig_classes=30)

    tgt_dataset = ConcatDataset([tgt_train_dataset, tgt_val_dataset, tgt_test_dataset])
    loader = torch.utils.data.DataLoader(tgt_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=False)
    return loader
def get_digital_target_dataloader(args, dname):
    images, labels = None, []
    if dname == 'mnist':
        images, labels = load_mnist_dataset(split='test')
    elif dname == 'svhn':
        images, labels = load_svhn_dataset(split='test')
    elif dname == 'usps':
        images, labels = load_usps_dataset(split='test')
    elif dname == 'mnist_m':
        images, labels = load_mnist_m_dataset(split='test')
    elif dname == 'syn':
        images, labels = load_syn_dataset(split='test')

    dataset_digit = AdvDataset(images, labels)
    if args.limit_target and len(dataset_digit) > args.limit_target:
        dataset_digit = Subset(dataset_digit, args.limit_target)

    dataset = ConcatDataset([dataset_digit])
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=False)
    return loader
Example #13
def get_par_val_dataloader(args, patches=False):
    names, labels = _dataset_info(
        join(dirname(__file__), 'txt_lists', '%s_test.txt' % args.target))
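    # The full image is only resized; each tile is converted to a tensor and normalized with ImageNet statistics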
    img_tr = [transforms.Resize((args.image_size, args.image_size))]
    tile_tr = [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
    img_transformer = transforms.Compose(img_tr)
    tile_transformer = transforms.Compose(tile_tr)
    val_dataset = PARDataset(names,
                             labels,
                             patches=patches,
                             img_transformer=img_transformer,
                             tile_transformer=tile_transformer)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    dataset = ConcatDataset([val_dataset])
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=False)
    return loader
Example #14
def get_val_dataloader(args, patches=False):
    if args.target in pacs_datasets:
        names, labels = _dataset_info(
            join(dirname(__file__), 'txt_lists', '%s_test.txt' % args.target))
    elif args.target in imagenet_datasets:
        names, labels = _dataset_info(
            join(dirname(__file__), 'txt_lists',
                 'imagenet_testDataPath.txt'))
    else:
        raise ValueError('Error: test dataset not found: %s' % args.target)
    img_tr = get_val_transformer(args)
    val_dataset = PARTestDataset(names,
                                 labels,
                                 patches=patches,
                                 img_transformer=img_tr)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    dataset = ConcatDataset([val_dataset])
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=False)
    return loader
Example #15
def get_single_test_dataloader(args, dname, patches=False):
    names, labels = _dataset_info(join(dirname(__file__), 'txt_lists', '%s_test.txt' % dname))
    img_tr = get_val_transformer(args)
    
    # JigsawTestDataset returns [unshuffled image, permutation order 0, class label]
    test_dataset = JigsawTestDataset(names, labels, patches=patches, img_transformer=img_tr, jig_classes=args.jigsaw_n_classes)
    if args.limit_target and len(test_dataset) > args.limit_target:
        test_dataset = Subset(test_dataset, args.limit_target)
        print("Using %d subset of test dataset" % args.limit_target)
    dataset = ConcatDataset([test_dataset])
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)
    return loader
Example #16
    def load_sentences(self, sentences, args, framework: str, language: str):
        def switch(f, l, s):
            return s if (framework == f and language == l) else []

        datasets = [
            dataset.load_sentences(switch(f, l, sentences), args, language)
            for (f, l), dataset in self.child_datasets.items()
        ]
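        # Only the child dataset matching (framework, language) receives the sentences; the others load
        # empty lists, so the concatenation preserves the child-dataset ordering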
        return torch.utils.data.DataLoader(ConcatDataset(datasets),
                                           batch_size=1,
                                           shuffle=False,
                                           collate_fn=Collate())
def get_train_dataloader(args):
    dataset_list = args.source
    assert isinstance(dataset_list, list)
    datasets = []
    val_datasets = []
    img_transformer = get_train_transformers(args)
    limit = args.limit_source

    for dname in dataset_list:
        if dname in digits_datasets:
            return get_digital_train_dataloader(args, dname)
        name_train, name_val, labels_train, labels_val = get_split_dataset_info(
            join(dirname(__file__), 'txt_lists', '%s_train.txt' % dname),
            args.val_size)
        train_dataset = JigsawDataset(name_train,
                                      labels_train,
                                      img_transformer=img_transformer)
        if limit:
            train_dataset = Subset(train_dataset, limit)
        datasets.append(train_dataset)
        val_datasets.append(
            JigsawDataset(name_val,
                          labels_val,
                          img_transformer=get_val_transformer(args)))
    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True,
                                             drop_last=False)
    return loader, val_loader
def append_adversarial_samples(args, data_loader, adv_data, adv_labels):
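    # Wrap the adversarial samples in a dataset, append it to the loader's existing datasets, and rebuild the train loader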
    datasets = data_loader.dataset.datasets

    dataset_adv = AdvDataset(adv_data, adv_labels)
    datasets.append(dataset_adv)

    dataset = ConcatDataset(datasets)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=True)
    return loader
Example #19
def get_val_dataloader(args):
    names, labels = _dataset_info(
        join(dirname(__file__), 'txt_lists', '%s_test.txt' % args.target))
    img_tr = get_val_transformer(args)
    val_dataset = BaselineDataset(names, labels, img_transformer=img_tr)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    dataset = ConcatDataset([val_dataset])
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=False)
    return loader
Example #20
def get_test_dataloader(args):
    img_tr = get_val_transformer(args)
    name = args.name
    if args.type == 'jigsaw':
        val_dataset = JigsawTestDataset(name, split='test',
                img_transformer=img_tr, jig_classes=args.aux_classes)
    elif args.type == 'rotate':
        val_dataset = RotateTestDataset(name, split='test',
                img_transformer=img_tr, rot_classes=args.aux_classes)

    if args.limit and len(val_dataset) > args.limit:
        val_dataset = Subset(val_dataset, args.limit)
        print("Using %d subset of dataset" % args.limit)
    dataset = ConcatDataset([val_dataset])
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)
    return loader
Example #21
def get_val_dataloader(args):

    names, labels = _dataset_info(
        join(dirname(__file__), 'txt_lists', args.target + '.txt'))
    img_tr = get_val_transformer(args)

    val_dataset = TestDataset(names,
                              labels,
                              args.path_dataset,
                              img_transformer=img_tr)
    dataset = ConcatDataset([val_dataset])
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=False)

    return loader
Example #22
def get_test_dataloader(args):

    name = args.name
    mode = args.get('mode', 'RGB')
    loaders = []
    img_trs = get_multi_crop_transformers(args)

    for img_tr in img_trs:
        if args.type == 'jigsaw':
            val_dataset = JigsawTestDataset(name,
                                            split='test',
                                            img_transformer=img_tr,
                                            jig_classes=args.aux_classes)
        elif args.type == 'rotate':
            val_dataset = RotateTestDataset(name,
                                            split='test',
                                            img_transformer=img_tr,
                                            rot_classes=args.aux_classes,
                                            mode=mode)
        elif args.type == 'image':
            val_dataset = ImageTestDataset(name,
                                           split='test',
                                           img_transformer=img_tr,
                                           mode=mode)

        if args.limit and len(val_dataset) > args.limit:
            val_dataset = Subset(val_dataset, args.limit)
            print("Using %d subset of dataset" % args.limit)
        dataset = ConcatDataset([val_dataset])
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True,
                                             drop_last=False)
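        # With multi_crop enabled, collect one loader per crop transformer; otherwise return the first loader immediately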
        if args.get('multi_crop', False):
            loaders.append(loader)
        else:
            return loader

    return loaders
Example #23
def get_trainTargetAsSource_dataloader(args):
    # Build a dataset of target-domain images prepared for the jigsaw task (used only when training with domain adaptation)
    names, labels = _dataset_info(
        join(dirname(__file__), 'txt_lists', args.target + '.txt'))
    img_transformer = get_train_transformers(args)
    train_dataset = Dataset(names,
                            labels,
                            args.path_dataset,
                            img_transformer=img_transformer,
                            betaJigen=args.betaJigen,
                            rotation=args.rotation,
                            oddOneOut=args.oddOneOut)
    #val_dataset = TestDataset(names, labels,args.path_dataset, img_transformer=img_tr,betaJigen = args.betaJigen)
    dataset = ConcatDataset([train_dataset])
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=True)

    return loader
Example #24
def get_jigsaw_val_dataloader(args, patches=False):
    if args.stylized:
        names, labels = _dataset_info(
            join(dirname(__file__), 'txt_lists', 'Stylized' + args.dataset,
                 "{}_target".format(args.target), '%s_test.txt' % args.target))
    else:
        names, labels = _dataset_info(
            join(dirname(__file__), 'txt_lists', 'Vanilla' + args.dataset,
                 '%s_test.txt' % args.target))

    img_tr = [transforms.Resize((args.image_size, args.image_size))]
    tile_tr = [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
    img_transformer = transforms.Compose(img_tr)
    tile_transformer = transforms.Compose(tile_tr)
    val_dataset = JigsawDataset(names,
                                labels,
                                patches=patches,
                                img_transformer=img_transformer,
                                tile_transformer=tile_transformer,
                                jig_classes=args.jigsaw_n_classes,
                                bias_whole_image=args.bias_whole_image,
                                grid_size=args.grid_size)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    dataset = ConcatDataset([val_dataset])
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=False)
    return loader
Example #25
def get_train_dataloader(args, patches):
    dataset_list = args.source
    assert isinstance(dataset_list, list)
    datasets = []
    val_datasets = []
    img_transformer, tile_transformer = get_train_transformers(args)
    limit = args.limit_source
    for dname in dataset_list:
        if args.stylized:
            name_train, name_val, labels_train, labels_val = get_split_dataset_info(
                join(dirname(__file__), 'txt_lists', 'Stylized' + args.dataset,
                     "{}_target".format(args.target), '%s_train.txt' % dname),
                args.val_size)
            # print(name_train)

        else:
            name_train, name_val, labels_train, labels_val = get_split_dataset_info(
                join(dirname(__file__), 'txt_lists', 'Vanilla' + args.dataset,
                     '%s_train.txt' % dname), args.val_size)

        train_dataset = JigsawDataset(name_train,
                                      labels_train,
                                      patches=patches,
                                      img_transformer=img_transformer,
                                      tile_transformer=tile_transformer,
                                      jig_classes=args.jigsaw_n_classes,
                                      bias_whole_image=args.bias_whole_image,
                                      grid_size=args.grid_size)

        if limit:
            train_dataset = Subset(train_dataset, limit)
        datasets.append(train_dataset)
        if args.jig_only:
            val_datasets.append(
                JigsawDataset(name_val,
                              labels_val,
                              patches=patches,
                              img_transformer=img_transformer,
                              tile_transformer=tile_transformer,
                              jig_classes=args.jigsaw_n_classes,
                              bias_whole_image=args.bias_whole_image,
                              grid_size=args.grid_size))
        else:
            val_datasets.append(
                JigsawTestDataset(name_val,
                                  labels_val,
                                  img_transformer=get_val_transformer(args),
                                  patches=patches,
                                  jig_classes=args.jigsaw_n_classes))

        #val_datasets.append(JigsawTestDataset(name_val, labels_val, img_transformer=get_val_transformer(args),
        #   patches=patches, jig_classes=args.jigsaw_n_classes))
    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True,
                                             drop_last=False)
    return loader, val_loader
def get_train_dataloader_JiGen(args, device, task="DG"):

    dataset_list = args.source
    assert isinstance(dataset_list, list)

    target = args.target

    if task == "DA":
        if target in dataset_list:
            dataset_list.remove(target)

    datasets = []
    val_datasets = []
    img_transformer, patch_transformer = get_train_transformers(args)
    val_transformer = get_val_transformer(args)

    for dname in dataset_list:
        name_train, name_val, labels_train, labels_val = get_split_dataset_info(
            join(dirname(__file__), 'txt_lists', dname + '.txt'),
            args.val_size)

        if target == dname:
            name_train += name_val
            labels_train += labels_val

        train_dataset = JigsawDataset(name_train,
                                      labels_train,
                                      args.path_dataset,
                                      args.scrambled,
                                      args.rotated,
                                      args.odd,
                                      args.jigen_transf,
                                      args.grid_size,
                                      args.permutation_number,
                                      device,
                                      task,
                                      target_name=args.target,
                                      img_transformer=img_transformer,
                                      patch_transformer=patch_transformer)
        datasets.append(train_dataset)

        if target != dname:
            val_dataset = TestDataset(name_val,
                                      labels_val,
                                      args.path_dataset,
                                      img_transformer=val_transformer)
            val_datasets.append(val_dataset)

    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)

    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=0,
                                         pin_memory=True,
                                         drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=0,
                                             pin_memory=True,
                                             drop_last=False)

    return loader, val_loader
Example #27
def get_train_dataloader(args, patches):
    dataset_list = args.source
    assert isinstance(dataset_list, list)
    datasets = []
    val_datasets = []
    img_transformer, tile_transformer = get_train_transformers(args)
    limit = args.limit_source
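    # PACS sources are split per domain from the txt lists; ImageNet uses precomputed train/val path files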
    if dataset_list[0] in pacs_datasets:
        for dname in dataset_list:
            name_train, name_val, labels_train, labels_val = get_split_dataset_info(
                join(dirname(__file__), 'txt_lists', '%s_train.txt' % dname),
                args.val_size)
            train_dataset = PARDataset(name_train,
                                       labels_train,
                                       patches=patches,
                                       img_transformer=img_transformer,
                                       tile_transformer=tile_transformer)
            if limit:
                train_dataset = Subset(train_dataset, limit)
            datasets.append(train_dataset)
            val_datasets.append(
                PARTestDataset(name_val,
                               labels_val,
                               img_transformer=get_val_transformer(args),
                               patches=patches))
    elif dataset_list[0] in imagenet_datasets:
        name_train, labels_train = _dataset_info(
            join(dirname(__file__), 'txt_lists',
                 'imagenet_trainDataPath.txt'))
        train_dataset = PARDataset(name_train,
                                   labels_train,
                                   patches=patches,
                                   img_transformer=img_transformer,
                                   tile_transformer=tile_transformer)
        datasets.append(train_dataset)
        name_val, labels_val = _dataset_info(
            join(dirname(__file__), 'txt_lists',
                 'imagenet_valDataPath.txt'))
        val_datasets.append(
            PARTestDataset(name_val,
                           labels_val,
                           img_transformer=get_val_transformer(args),
                           patches=patches))
    else:
        raise ValueError('Error: dataset not found: %s' % dataset_list[0])
    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True,
                                             drop_last=False)
    return loader, val_loader