Example #1
def return_dataset_test(args):
    base_path = './data/txt/%s' % args.dataset
    root = './data/%s/' % args.dataset
    image_set_file_s = os.path.join(base_path, args.source + '_all' + '.txt')
    image_set_file_test = os.path.join(
        base_path,
        'unlabeled_target_images_' + args.target + '_%d.txt' % (args.num))
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'test':
        transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    target_dataset_unl = Imagelists_VISDA(image_set_file_test,
                                          root=root,
                                          transform=data_transforms['test'],
                                          test=True)
    class_list = return_classlist(image_set_file_test)
    num_images = len(target_dataset_unl)
    if args.net == 'alexnet':
        bs = 32
    else:
        bs = 24
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    batch_size=bs * 2, num_workers=3,
                                    shuffle=False, drop_last=False)
    return target_loader_unl, class_list
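These return_dataset_* helpers read only a handful of attributes from args, so they can be exercised without a full command-line parser. A minimal usage sketch, assuming the corresponding list files exist under ./data/txt/ (the concrete dataset and domain names here are hypothetical):

import argparse

# Hypothetical values; the attribute names (dataset, source, target, num,
# net) are exactly the ones return_dataset_test reads above.
args = argparse.Namespace(dataset='multi', source='real',
                          target='sketch', num=3, net='resnet34')
target_loader_unl, class_list = return_dataset_test(args)
print('%d test batches over %d classes'
      % (len(target_loader_unl), len(class_list)))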
Example #2
def get_dataset(net, root, image_set_file_test):
    if net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    target_dataset_unl = Imagelists_VISDA(image_set_file_test, root=root,
                                            transform=data_transforms['test'],
                                            test=True)
    class_list = return_classlist(image_set_file_test)
    num_images = len(target_dataset_unl)
    bs = 1  # one image per batch, for per-sample test-time evaluation

    target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                                    batch_size=bs,
                                                    num_workers=3,
                                                    shuffle=False,
                                                    drop_last=False)
    return target_loader_unl, class_list
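Both examples so far rely on ResizeImage, which is not a torchvision transform. A minimal stand-in consistent with how it is used here, resizing a PIL image to a fixed square (the helper in the original codebase may behave differently):

class ResizeImage(object):
    # Resize a PIL image to (size, size); accepts an int or a pair.
    def __init__(self, size):
        self.size = (size, size) if isinstance(size, int) else size

    def __call__(self, img):
        th, tw = self.size
        return img.resize((tw, th))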
Example #3
def return_dataset_train_eval(args):
    base_path = './data/txt/%s' % args.dataset
    root = './data/%s/' % args.dataset
    image_set_file_s = os.path.join(
        base_path, 'labeled_source_images_' + args.source + '.txt')
    image_set_file_t = os.path.join(
        base_path,
        'labeled_target_images_' + args.target + '_%d.txt' % (args.num))
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'test':
        transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    source_dataset = Imagelists_VISDA([image_set_file_s, image_set_file_t],
                                      root=root,
                                      transform=data_transforms['test'],
                                      test=True,
                                      multiple_files=True)
    if args.net == 'alexnet':
        bs = 32
    else:
        bs = 24
    source_loader = \
        torch.utils.data.DataLoader(source_dataset,
                                    batch_size=bs * 2, num_workers=3,
                                    shuffle=False, drop_last=False)
    return source_loader
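Every example derives class_list with return_classlist. A plausible minimal implementation, assuming each line of the list file looks like '<...>/<class_name>/<image> <label_index>' (an assumption; the original helper may differ):

def return_classlist(image_list):
    # Unique class names, taken from the parent directory of each image
    # path, in first-seen order.
    label_list = []
    with open(image_list) as f:
        for line in f:
            label = line.split(' ')[0].split('/')[-2]
            if label not in label_list:
                label_list.append(label)
    return label_list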
Example #4
def return_dataset(args):
    base_path = './data/txt'
    image_set_file_s = os.path.join(base_path, args.source + '_all.txt')
    image_set_file_t = os.path.join(base_path, args.target + '_labeled.txt')
    image_set_file_test = os.path.join(base_path, args.target + '_unl.txt')
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    source_dataset = Imagelists_VISDA(image_set_file_s, transform=data_transforms['train'])
    target_dataset = Imagelists_VISDA(image_set_file_t, transform=data_transforms['val'])
    target_dataset_unl = Imagelists_VISDA(image_set_file_test, transform=data_transforms['val'])
    class_list = return_classlist(image_set_file_s)
    print("%d classes in this dataset"%len(class_list))
    if args.net == 'alexnet':
        bs = 32
    else:
        bs = 24
    source_loader = torch.utils.data.DataLoader(source_dataset, batch_size=bs, num_workers=3, shuffle=True,
                                                drop_last=True)
    target_loader = torch.utils.data.DataLoader(target_dataset, batch_size=min(bs, len(target_dataset)),
                                                num_workers=3, shuffle=True, drop_last=True)
    target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl, batch_size=bs * 2, num_workers=3,
                                                    shuffle=True, drop_last=True)
    return source_loader, target_loader, target_loader_unl, class_list
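All of these functions consume plain-text image lists. Assuming the '<relative_path> <label_index>' line format implied by the parsing above, such a list can be generated from a class-per-folder directory with a small hypothetical helper:

import os

def write_image_list(image_root, out_txt):
    # Walk a directory laid out as <image_root>/<class_name>/<image> and
    # emit one "relative_path label_index" line per image. The format is
    # an assumption inferred from how the examples parse these files.
    classes = sorted(os.listdir(image_root))
    with open(out_txt, 'w') as f:
        for idx, cls in enumerate(classes):
            cls_dir = os.path.join(image_root, cls)
            for name in sorted(os.listdir(cls_dir)):
                f.write('%s %d\n' % (os.path.join(cls, name), idx))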
Example #5
def return_dataset_test_unseen(args):
    #base_path = './data/txt/%s' % args.dataset
    #root = './data/%s/' % args.dataset
    #image_set_file_s = os.path.join(base_path, args.source + '_all' + '.txt')
    #image_set_file_s = "/cbica/home/bhaleram/comp_space/random/personal/iisc_project/MME/data/txt/final_test_target_painting.txt"
    #image_set_file_test = os.path.join(base_path,'unlabeled_target_images_' +args.target + '_%d.txt' % (args.num))
    image_set_file_test = "/cbica/home/bhaleram/comp_space/random/personal/iisc_project/MME/data/txt/final_test_target_painting.txt"
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'test':
        transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    target_dataset_unl = Imagelists_VISDA(image_set_file_test,
                                          transform=data_transforms['test'],
                                          test=True)
    class_list = return_classlist(
        "/cbica/home/bhaleram/comp_space/random/personal/iisc_project/MME/data/txt/labeled_source_images_real.txt"
    )
    print("%d classes in this dataset" % len(class_list))
    if args.net == 'alexnet':
        bs = 32
    else:
        bs = 24
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    batch_size=bs * 2, num_workers=3,
                                    shuffle=False, drop_last=False)
    return target_loader_unl, class_list
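Once a test loader like this is built, evaluation is a plain no-grad pass. A sketch assuming each batch unpacks as (images, labels); with test=True the dataset may yield extra fields such as file paths, in which case the unpacking needs adjusting:

import torch

def evaluate(model, loader, device='cuda'):
    # Accuracy over a loader; assumes (images, labels) batches.
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            pred = model(images).argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    return correct / max(total, 1)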
Example #6
def return_dataset_transfer(args):
    base_path = './data/txt/%s' % args.dataset
    root = './data/%s/' % args.dataset
    image_set_file_s = \
        os.path.join(base_path,
                     'labeled_source_images_' +
                     args.source + '.txt')
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.target + '_%d.txt' % (args.num))

    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.target + '_3.txt')
    image_set_file_unl = \
        os.path.join(base_path,
                     'unlabeled_target_images_' +
                     args.target + '_%d.txt' % (args.num))
    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train':
        transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val':
        transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test':
        transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    source_dataset = Imagelists_VISDA(image_set_file_s,
                                      root=root,
                                      transform=data_transforms['train'],
                                      test=True)
    target_dataset = Imagelists_VISDA(image_set_file_t,
                                      root=root,
                                      transform=data_transforms['val'],
                                      test=True)
    target_dataset_val = Imagelists_VISDA(image_set_file_t_val,
                                          root=root,
                                          transform=data_transforms['val'])
    target_dataset_unl = Imagelists_VISDA(image_set_file_unl,
                                          root=root,
                                          transform=data_transforms['val'])
    target_dataset_test = Imagelists_VISDA(image_set_file_unl,
                                           root=root,
                                           transform=data_transforms['test'])
    class_list = return_classlist(image_set_file_s)
    class_num_list = return_number_of_label_per_class(image_set_file_s,
                                                      len(class_list))

    if args.net == 'alexnet':
        bs = 32
    else:
        bs = 24
    source_loader = torch.utils.data.DataLoader(source_dataset,
                                                batch_size=bs,
                                                num_workers=3,
                                                shuffle=True,
                                                drop_last=True)
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    num_workers=3,
                                    shuffle=True, drop_last=True)
    target_loader_val = \
        torch.utils.data.DataLoader(target_dataset_val,
                                    batch_size=min(bs, len(target_dataset_val)),
                                    num_workers=3,
                                    shuffle=True, drop_last=True)
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    batch_size=bs * 2, num_workers=3,
                                    shuffle=True, drop_last=True)
    target_loader_test = \
        torch.utils.data.DataLoader(target_dataset_test,
                                    batch_size=bs * 6, num_workers=3,
                                    shuffle=True, drop_last=True)
    return source_loader, target_loader, target_loader_unl, \
          target_loader_val, target_loader_test, class_num_list, class_list
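return_dataset_transfer additionally returns per-class label counts via return_number_of_label_per_class. A plausible minimal version, under the same '<path> <label_index>' line-format assumption as above (the real helper may differ):

def return_number_of_label_per_class(image_list, num_classes):
    # Count how many lines in the list file carry each label index.
    counts = [0] * num_classes
    with open(image_list) as f:
        for line in f:
            counts[int(line.split(' ')[1])] += 1
    return counts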
Example #7
def return_dataset(args):
    base_path = './data/txt'
    image_set_file_s = os.path.join(base_path, args.source + '_all.txt')
    image_set_file_t = os.path.join(base_path, args.target + '_labeled.txt')
    image_set_file_test = os.path.join(base_path, args.target + '_unl.txt')
    if args.source_model == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224

    train_transforms = []
    val_transforms = []
    test_transforms = []

    train_transforms.append(ResizeImage(256))
    train_transforms.append(transforms.RandomHorizontalFlip())
    train_transforms.append(transforms.RandomCrop(crop_size))

    train_transforms.append(transforms.ToTensor())
    train_transforms.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))

    val_transforms.append(ResizeImage(256))
    val_transforms.append(transforms.RandomHorizontalFlip())
    val_transforms.append(transforms.RandomCrop(crop_size))
    # if args.use_rotation:
    #     val_transforms.append(transforms.RandomRotation(degrees=(0, 360)))
    val_transforms.append(transforms.ToTensor())
    val_transforms.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))

    test_transforms.append(ResizeImage(256))
    test_transforms.append(transforms.CenterCrop(crop_size))
    test_transforms.append(transforms.ToTensor())
    test_transforms.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))

    data_transforms = {
        'train': transforms.Compose(train_transforms),
        'val': transforms.Compose(val_transforms),
        'test': transforms.Compose(test_transforms),
    }
    source_dataset = Imagelists_VISDA(image_set_file_s, transform=data_transforms['train'],
                                      use_contour=False)
    target_dataset = Imagelists_VISDA(image_set_file_t, transform=data_transforms['val'],
                                      use_contour=False)
    target_dataset_unl = Imagelists_VISDA(image_set_file_test, transform=data_transforms['val'],
                                          use_contour=False)

    print('source size: {}, target size: {}, target_unl size: {}'
          .format(len(source_dataset), len(target_dataset), len(target_dataset_unl)))

    class_list = return_classlist(image_set_file_s)
    print("%d classes in this dataset"%len(class_list))
    if args.source_model == 'alexnet':
        bs = 32
    elif 'se_resne' in args.source_model:
        bs = 12
    else:
        bs = 24
    source_loader = torch.utils.data.DataLoader(source_dataset, batch_size=bs, num_workers=3, shuffle=True,
                                                drop_last=True)
    target_loader = torch.utils.data.DataLoader(target_dataset, batch_size=min(bs, len(target_dataset)),
                                                num_workers=3, shuffle=True, drop_last=True)
    target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl, batch_size=bs * 2, num_workers=3,
                                                    shuffle=True, drop_last=True)
    return source_loader, target_loader, target_loader_unl, class_list
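The three loaders returned here are usually drawn from in lock-step for a fixed number of optimization steps rather than whole epochs, restarting the small labeled-target loader as needed. A sketch, assuming each batch is an (images, labels) pair and with a hypothetical step budget:

def cycle(loader):
    # Restart the underlying loader whenever it is exhausted, so small
    # loaders can be drawn from indefinitely.
    while True:
        for batch in loader:
            yield batch

iter_s, iter_t, iter_unl = (cycle(source_loader), cycle(target_loader),
                            cycle(target_loader_unl))
for step in range(10000):  # step budget is a hypothetical choice
    im_s, y_s = next(iter_s)
    im_t, y_t = next(iter_t)
    im_unl, _ = next(iter_unl)
    # ... one optimization step over the three batches ...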
Example #8
def return_dataset(args):
    base_path = './data/txt/%s' % args.dataset
    root = './data/%s/' % args.dataset

    image_set_file_s = \
        os.path.join(base_path,
                     'labeled_source_images_' +
                     args.source + '.txt')
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.target + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.target + '_3.txt')
    image_set_file_unl = \
        os.path.join(base_path,
                     'unlabeled_target_images_' +
                     args.target + '_%d.txt' % (args.num))

    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train':
        transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val':
        transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'strong':
        transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            RandAugmentMC(n=2, m=10),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]),
        'test':
        transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    source_dataset = Imagelists_VISDA(
        image_set_file_s,
        root=root,
        transform=data_transforms['train'],
        strong_transform=data_transforms['strong'])
    target_dataset = Imagelists_VISDA(
        image_set_file_t,
        root=root,
        transform=data_transforms['val'],
        strong_transform=data_transforms['strong'])
    target_dataset_val = Imagelists_VISDA(
        image_set_file_t_val,
        root=root,
        transform=data_transforms['val'],
        strong_transform=data_transforms['strong'])
    target_dataset_unl = Imagelists_VISDA(
        image_set_file_unl,
        root=root,
        transform=data_transforms['val'],
        strong_transform=data_transforms['strong'])
    target_dataset_test = Imagelists_VISDA(image_set_file_unl,
                                           root=root,
                                           transform=data_transforms['test'],
                                           test=True)
    class_list = return_classlist(image_set_file_s)
    print("%d classes in this dataset" % len(class_list))

    return source_dataset, target_dataset, target_dataset_unl, target_dataset_val, target_dataset_test, class_list
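Unlike the other variants, this one returns datasets rather than loaders, leaving batching to the caller. All of these examples lean on Imagelists_VISDA; a minimal sketch of such a list-file dataset follows (the real class also accepts strong_transform, multiple_files, test, rotate and use_contour, which are omitted here):

import os
from PIL import Image
from torch.utils.data import Dataset

class Imagelists_VISDA(Dataset):
    # Minimal sketch: each line of image_list is '<relative_path> <label>'.
    def __init__(self, image_list, root='./data/', transform=None):
        self.root, self.transform = root, transform
        self.samples = []
        with open(image_list) as f:
            for line in f:
                path, label = line.split()
                self.samples.append((path, int(label)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label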
Example #9
def return_dataset_s4l_fixmatch(args):
    base_path = './data/txt/%s' % args.dataset
    root = './data/%s/' % args.dataset
    image_set_file_s = \
        os.path.join(base_path,
                     'labeled_source_images_' +
                     args.source + '.txt')
    image_set_file_t = \
        os.path.join(base_path,
                     'labeled_target_images_' +
                     args.target + '_%d.txt' % (args.num))
    image_set_file_t_val = \
        os.path.join(base_path,
                     'validation_target_images_' +
                     args.target + '_3.txt')
    image_set_file_unl = \
        os.path.join(base_path,
                     'unlabeled_target_images_' +
                     args.target + '_%d.txt' % (args.num))

    if args.net == 'alexnet':
        crop_size = 227
    else:
        crop_size = 224
    data_transforms = {
        'train':
        TransformRotate(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                        crop_size=crop_size),
        'train_ut':
        TransformRotateFix(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225],
                           crop_size=crop_size),
        'val':
        transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test':
        transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    source_dataset = Imagelists_VISDA(image_set_file_s,
                                      root=root,
                                      transform=data_transforms['train'],
                                      rotate=True)
    target_dataset = Imagelists_VISDA(image_set_file_t,
                                      root=root,
                                      transform=data_transforms['train'],
                                      rotate=True)
    target_dataset_val = Imagelists_VISDA(image_set_file_t_val,
                                          root=root,
                                          transform=data_transforms['val'])
    target_dataset_unl = Imagelists_VISDA(
        image_set_file_unl,
        root=root,
        transform=data_transforms['train_ut'],
        rotate=True)
    target_dataset_test = Imagelists_VISDA(image_set_file_unl,
                                           root=root,
                                           transform=data_transforms['test'])
    class_list = return_classlist(image_set_file_s)
    print("%d classes in this dataset" % len(class_list))
    if args.net == 'alexnet':
        bs = 32  # 32
    else:
        bs = 10  # 24
    num_workers = 3
    source_loader = \
        torch.utils.data.DataLoader(source_dataset,
                                    batch_size=bs,
                                    num_workers=num_workers, shuffle=True,
                                    drop_last=True)
    target_loader = \
        torch.utils.data.DataLoader(target_dataset,
                                    batch_size=min(bs, len(target_dataset)),
                                    num_workers=num_workers,
                                    shuffle=True, drop_last=True)
    target_loader_val = \
        torch.utils.data.DataLoader(target_dataset_val,
                                    batch_size=min(bs,
                                                   len(target_dataset_val)),
                                    num_workers=num_workers,
                                    shuffle=True, drop_last=True)  # drop_last should be set to False
    target_loader_unl = \
        torch.utils.data.DataLoader(target_dataset_unl,
                                    batch_size=bs, num_workers=num_workers,
                                    shuffle=True, drop_last=True)
    target_loader_test = \
        torch.utils.data.DataLoader(target_dataset_test,
                                    batch_size=bs, num_workers=num_workers,
                                    shuffle=True, drop_last=True)  # drop_last should be set to False
    return source_loader, target_loader, target_loader_unl, \
        target_loader_val, target_loader_test, class_list
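The inline comments above note that drop_last should be False for the validation and test loaders; as written, up to bs - 1 trailing samples are silently skipped each pass. A corrected construction (shuffle=False is also the usual choice for evaluation loaders):

target_loader_val = torch.utils.data.DataLoader(
    target_dataset_val,
    batch_size=min(bs, len(target_dataset_val)),
    num_workers=num_workers,
    shuffle=False, drop_last=False)
target_loader_test = torch.utils.data.DataLoader(
    target_dataset_test,
    batch_size=bs, num_workers=num_workers,
    shuffle=False, drop_last=False)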
Example #10
if net == 'alexnet':
    crop_size = 227
else:
    crop_size = 224
data_transforms = {
    'test':
    transforms.Compose([
        ResizeImage(256),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
target_dataset_unl = Imagelists_VISDA(image_set_file_test,
                                      root=root,
                                      transform=data_transforms['test'],
                                      test=True)
class_list = return_classlist(image_set_file_test)
num_images = len(target_dataset_unl)
bs = 1  # one image per batch for inference

target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl,
                                                batch_size=bs,
                                                num_workers=3,
                                                shuffle=False,
                                                drop_last=False)

# Defining the PyTorch networks
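The fragment stops where the networks are defined. A plausible continuation, assuming an ImageNet-pretrained torchvision backbone split into a feature extractor G and a linear classifier F over class_list (the networks in the original codebase may differ):

import torch.nn as nn
from torchvision import models

G = models.resnet34(pretrained=True)
feat_dim = G.fc.in_features
G.fc = nn.Identity()                      # backbone now emits pooled features
F = nn.Linear(feat_dim, len(class_list))  # classification head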
Example #11
def return_dataset(args):
    print("entered return_dataset function")
    base_path = './data/txt'
    image_set_file_s1 = os.path.join(base_path, args.source1 + '_all' + '.txt')
    image_set_file_s2 = os.path.join(base_path, args.source2 + '_all' + '.txt')
    image_set_file_t = os.path.join(base_path, args.target + '_all' + '.txt')
    #image_set_file_test = os.path.join(base_path, args.target + '_unl' + '.txt')
    if args.net == 'alexnet':
        print("network is alexnet, crop size is 227")
        crop_size = 227
    else:
        print("network is ", args.net, " crop size is 224")
        crop_size = 224
    print("transforming data")
    data_transforms = {
        'train1':
        transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'train2':
        transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val':
        transforms.Compose([
            ResizeImage(256),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test':
        transforms.Compose([
            ResizeImage(256),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    print("reading datasets")
    source_dataset1 = Imagelists_VISDA(image_set_file_s1,
                                       transform=data_transforms['train1'])
    source_dataset2 = Imagelists_VISDA(image_set_file_s2,
                                       transform=data_transforms['train2'])
    target_dataset = Imagelists_VISDA(image_set_file_t,
                                      transform=data_transforms['val'])
    #target_dataset_unl = Imagelists_VISDA(image_set_file_test, transform=data_transforms['val'])

    class_list1 = return_classlist(image_set_file_s1)
    class_list2 = return_classlist(image_set_file_s2)
    print("%d classes in this dataset (based on source 1)" % len(class_list1))
    print("%d classes in this dataset (based on source 2)" % len(class_list2))

    if args.net == 'alexnet':
        print("network is alexnet, batch size is 32")
        bs = 32
    else:
        print("network is ", args.net, " batch size is 24")
        bs = 24
    print("loading datasets")
    source_loader1 = torch.utils.data.DataLoader(source_dataset1,
                                                 batch_size=bs,
                                                 num_workers=3,
                                                 shuffle=True,
                                                 drop_last=True)
    source_loader2 = torch.utils.data.DataLoader(source_dataset2,
                                                 batch_size=bs,
                                                 num_workers=3,
                                                 shuffle=True,
                                                 drop_last=True)
    target_loader = torch.utils.data.DataLoader(target_dataset,
                                                batch_size=min(
                                                    bs, len(target_dataset)),
                                                num_workers=3,
                                                shuffle=True,
                                                drop_last=True)
    #target_loader_unl = torch.utils.data.DataLoader(target_dataset_unl, batch_size=bs * 2, num_workers=3,
    #shuffle=True, drop_last=True)
    print("returning dataset")
    return source_loader1, source_loader2, target_loader, class_list1, class_list2
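With two source loaders, a training step typically consumes one batch from each source plus one from the target. A minimal sketch; zip bounds the epoch by the shortest loader, so step-based training would instead wrap each loader in a restarting iterator like the cycle helper sketched after Example #7:

for (im_s1, y_s1), (im_s2, y_s2), (im_t, y_t) in zip(source_loader1,
                                                     source_loader2,
                                                     target_loader):
    # ... one optimization step combining the two source batches and the
    # target batch ...
    pass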