# Example #1
def get_val_dataloader(args, patches=False):
    """Build a deterministic (unshuffled) loader over the target domain's test split.

    Chooses the 'Stylized' or 'Vanilla' txt list depending on ``args.stylized``,
    optionally truncates the dataset to ``args.limit_target`` samples, and wraps
    it in a ConcatDataset before handing it to a DataLoader.
    """
    # Resolve the test-list file first, then read it once.
    if args.stylized:
        list_path = join(dirname(__file__), 'txt_lists', 'Stylized' + args.dataset,
                         "{}_target".format(args.target), '%s_test.txt' % args.target)
    else:
        list_path = join(dirname(__file__), 'txt_lists', 'Vanilla' + args.dataset,
                         '%s_test.txt' % args.target)
    names, labels = _dataset_info(list_path)

    test_set = JigsawTestDataset(names,
                                 labels,
                                 patches=patches,
                                 img_transformer=get_val_transformer(args),
                                 jig_classes=args.jigsaw_n_classes)
    # Optionally cap the evaluation set size for faster runs.
    if args.limit_target and len(test_set) > args.limit_target:
        test_set = Subset(test_set, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    return torch.utils.data.DataLoader(ConcatDataset([test_set]),
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=4,
                                       pin_memory=True,
                                       drop_last=False)
# Example #2
def get_train_dataloader(args, patches):
    """Build train and validation loaders over every source domain.

    For each source name, picks Nexperia-share roots when the name is in
    ``new_nex_datasets`` (otherwise the local kfold/txt-list roots), reads the
    train/val index files, and concatenates the per-domain datasets. The train
    loader shuffles and drops the last partial batch; the val loader does neither.
    """
    dataset_list = args.source
    assert isinstance(dataset_list, list)

    img_transformer, tile_transformer = get_train_transformers(args)
    limit = args.limit_source
    train_parts, val_parts = [], []
    for dname in dataset_list:
        # New Nexperia datasets live on the shared mount; everything else
        # resolves relative to this module.
        if dname in new_nex_datasets:
            index_root = data_root = '/import/home/share/from_Nexperia_April2021/%s' % dname
        else:
            index_root = join(dirname(__file__),'correct_txt_lists')
            data_root = join(dirname(__file__),'kfold')

        tr_names, tr_labels = _dataset_info(join(index_root, "%s_train.txt" % dname))
        va_names, va_labels = _dataset_info(join(index_root, "%s_val.txt" % dname))

        tr_set = JigsawNewDataset(data_root, tr_names, tr_labels, patches=patches,
                                  img_transformer=img_transformer,
                                  tile_transformer=tile_transformer, jig_classes=30,
                                  bias_whole_image=args.bias_whole_image)
        if limit:
            tr_set = Subset(tr_set, limit)
        train_parts.append(tr_set)
        val_parts.append(JigsawTestNewDataset(data_root, va_names, va_labels,
                                              img_transformer=get_val_transformer(args),
                                              patches=patches, jig_classes=30))

    loader = torch.utils.data.DataLoader(ConcatDataset(train_parts),
                                         batch_size=args.batch_size, shuffle=True,
                                         num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(ConcatDataset(val_parts),
                                             batch_size=args.batch_size, shuffle=False,
                                             num_workers=4, pin_memory=True, drop_last=False)
    return loader, val_loader
# Example #3
def get_train_dataloader(args, patches):
    """Build train and cross-validation loaders from the kfold txt lists.

    Reads ``<name>_train_kfold.txt`` / ``<name>_crossval_kfold.txt`` for each
    source domain, optionally truncates each training set to
    ``args.limit_source`` samples, and concatenates everything. The train
    loader shuffles and drops the last partial batch; the val loader is
    deterministic and keeps all samples.
    """
    dataset_list = args.source
    assert isinstance(dataset_list, list)

    img_transformer, tile_transformer = get_train_transformers(args)
    limit = args.limit_source
    train_parts = []
    val_parts = []
    for dname in dataset_list:
        tr_names, tr_labels = _dataset_info(
            join(dirname(__file__), 'correct_txt_lists',
                 '%s_train_kfold.txt' % dname))
        va_names, va_labels = _dataset_info(
            join(dirname(__file__), 'correct_txt_lists',
                 '%s_crossval_kfold.txt' % dname))

        tr_set = JigsawNewDataset(tr_names,
                                  tr_labels,
                                  patches=patches,
                                  img_transformer=img_transformer,
                                  tile_transformer=tile_transformer,
                                  jig_classes=args.jigsaw_n_classes,
                                  bias_whole_image=args.bias_whole_image)
        if limit:
            tr_set = Subset(tr_set, limit)
        train_parts.append(tr_set)
        val_parts.append(
            JigsawTestNewDataset(va_names,
                                 va_labels,
                                 img_transformer=get_val_transformer(args),
                                 patches=patches,
                                 jig_classes=args.jigsaw_n_classes))

    loader = torch.utils.data.DataLoader(ConcatDataset(train_parts),
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=True)
    val_loader = torch.utils.data.DataLoader(ConcatDataset(val_parts),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=4,
                                             pin_memory=True,
                                             drop_last=False)
    return loader, val_loader
# Example #4
def get_jigsaw_val_dataloader(args, patches=False):
    """Build an unshuffled loader over the target test split for jigsaw evaluation.

    Unlike the plain val loader this uses ``JigsawDataset`` (which permutes
    tiles), with a resize-only image transform and an ImageNet-normalized
    tensor transform for the tiles.
    """
    names, labels = _dataset_info(
        join(dirname(__file__), 'txt_lists', '%s_test.txt' % args.target))
    # Image-level transform: resize only; tile-level: to-tensor + ImageNet stats.
    img_transformer = transforms.Compose(
        [transforms.Resize((args.image_size, args.image_size))])
    tile_transformer = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    eval_set = JigsawDataset(names,
                             labels,
                             patches=patches,
                             img_transformer=img_transformer,
                             tile_transformer=tile_transformer,
                             jig_classes=args.jigsaw_n_classes,
                             bias_whole_image=args.bias_whole_image)
    # Optionally cap the evaluation set size for faster runs.
    if args.limit_target and len(eval_set) > args.limit_target:
        eval_set = Subset(eval_set, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    return torch.utils.data.DataLoader(ConcatDataset([eval_set]),
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=4,
                                       pin_memory=True,
                                       drop_last=False)
# Example #5
def get_val_dataloader(args, patches=False):
    """Build a loader over the target domain's validation split.

    Resolves Nexperia-share roots when the target is in ``new_nex_datasets``,
    otherwise uses the local kfold/txt-list roots; optionally truncates the
    dataset to ``args.limit_target`` samples.
    """
    dname = args.target
    # New Nexperia datasets live on the shared mount; everything else
    # resolves relative to this module.
    if dname in new_nex_datasets:
        index_root = data_root = '/import/home/share/from_Nexperia_April2021/%s' % dname
    else:
        index_root = join(dirname(__file__), 'correct_txt_lists')
        data_root = join(dirname(__file__), 'kfold')
    names, labels = _dataset_info(join(index_root, "%s_val.txt" % dname))
    img_tr = get_nex_val_transformer(args)
    val_dataset = JigsawTestNewDataset(data_root,
                                       names,
                                       labels,
                                       patches=patches,
                                       img_transformer=img_tr,
                                       jig_classes=30)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    dataset = ConcatDataset([val_dataset])
    # Fix: validation should be deterministic — every other val/test loader in
    # this module uses shuffle=False; shuffling here was inconsistent and
    # served no purpose for evaluation.
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=4,
                                         pin_memory=True,
                                         drop_last=False)
    return loader
# Example #6
def get_single_test_dataloader(args, dname, patches=False):
    """Build an unshuffled loader over one domain's test split.

    ``dname`` selects the txt list; the dataset may be truncated to
    ``args.limit_target`` samples.
    """
    names, labels = _dataset_info(join(dirname(__file__), 'txt_lists', '%s_test.txt' % dname))

    # JigsawTestDataset return [unsorted_image, permutation_order-0, class_label]
    test_set = JigsawTestDataset(names, labels,
                                 patches=patches,
                                 img_transformer=get_val_transformer(args),
                                 jig_classes=args.jigsaw_n_classes)
    if args.limit_target and len(test_set) > args.limit_target:
        test_set = Subset(test_set, args.limit_target)
        print("Using %d subset of test dataset" % args.limit_target)
    return torch.utils.data.DataLoader(ConcatDataset([test_set]),
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=4,
                                       pin_memory=True,
                                       drop_last=False)
# Example #7
def get_tgt_dataloader(args, patches=False):
    '''
    Load whole domain dataset: concatenates the target's train, val, and test
    splits into a single shuffled loader (e.g. for target-domain adaptation).
    Uses the downsampled index files when args.downsample_target is set.
    '''
    transform = get_nex_val_transformer(args)
    dname = args.target
    # New Nexperia datasets live on the shared mount; everything else
    # resolves relative to this module.
    if dname in new_nex_datasets:
        index_root = data_root = '/import/home/share/from_Nexperia_April2021/%s' % dname
    else:
        index_root = join(dirname(__file__),'correct_txt_lists')
        data_root = join(dirname(__file__),'kfold')

    if args.downsample_target:
        split_files = ["%s_train_down.txt" % dname,
                       "%s_val_down.txt" % dname,
                       "%s_test_down.txt" % dname]
    else:
        split_files = ["%s_train.txt" % dname,
                       "%s_val.txt" % dname,
                       "%s_test.txt" % dname]

    parts = []
    for fname in split_files:
        names, labels = _dataset_info(join(index_root, fname))
        parts.append(JigsawTestNewDataset(data_root, names, labels,
                                          patches=patches,
                                          img_transformer=transform,
                                          jig_classes=30))

    return torch.utils.data.DataLoader(ConcatDataset(parts),
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       num_workers=4,
                                       pin_memory=True,
                                       drop_last=False)
# Example #8
def get_val_dataloader(args):
    """Build an unshuffled loader over the target's test split using the
    plain (non-jigsaw) BaselineDataset; optionally truncated to
    ``args.limit_target`` samples.
    """
    names, labels = _dataset_info(
        join(dirname(__file__), 'txt_lists', '%s_test.txt' % args.target))
    eval_set = BaselineDataset(names, labels,
                               img_transformer=get_val_transformer(args))
    if args.limit_target and len(eval_set) > args.limit_target:
        eval_set = Subset(eval_set, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    return torch.utils.data.DataLoader(ConcatDataset([eval_set]),
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=4,
                                       pin_memory=True,
                                       drop_last=False)