Example #1 (score: 0)
File: main.py — Project: michellehan/ivc
def eval_create_data_loaders(train_transformation, eval_transformation, args):
    """Build the evaluation-only DataLoader over ``args.val_csv`` / ``args.val_dir``.

    ``train_transformation`` is accepted only for signature parity with
    ``create_data_loaders``; it is not used here.  Side effect: the dataset's
    ``class_to_idx`` mapping is stored on ``args``.

    Returns:
        torch.utils.data.DataLoader over the evaluation dataset.
    """
    print('Test Dataset: %s' % (args.val_dir))

    dataset = datasets.IVCdataset(args.val_csv, args.val_dir,
                                  eval_transformation)

    # Evaluation keeps sample order fixed and retains the final partial batch;
    # the worker count is doubled because eval needs images twice as fast.
    loader_kwargs = dict(
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2 * args.workers,
        pin_memory=True,
        drop_last=False,
    )
    loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

    args.class_to_idx = dataset.class_to_idx
    return loader
Example #2 (score: 0)
File: main.py — Project: michellehan/ivc
def create_data_loaders(train_transformation, eval_transformation, args):
    """Build the training and validation DataLoaders for the IVC dataset.

    Training and validation use the same official split directories taken
    from ``args``.  Side effect: the training dataset's ``class_to_idx``
    mapping is stored on ``args``.

    Args:
        train_transformation: transform applied to training images.
        eval_transformation: transform applied to validation images.
        args: parsed options; reads ``train_csv``/``train_dir``,
            ``val_csv``/``val_dir``, ``batch_size`` and ``workers``.

    Returns:
        (train_loader, val_loader) tuple of ``torch.utils.data.DataLoader``.
    """
    print('Training Dataset: %s' % (args.train_dir))
    print('Validation Dataset: %s' % (args.val_dir))

    # Customized training dataset: shuffled, incomplete last batch dropped.
    train_dataset = datasets.IVCdataset(args.train_csv, args.train_dir,
                                        train_transformation)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,  # no customized sampler, just batch size
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)

    # NOTE(review): the original semi-supervised branch (subset labeled CSV +
    # data.TwoStreamBatchSampler for flag != 'full') sat behind ``if True:``
    # and was unreachable dead code; it has been removed.  All labeled data
    # is always used, exactly as before.
    print('train loader for training on all labeled data!')

    # Customized validation dataset: fixed order, last partial batch kept.
    val_dataset = datasets.IVCdataset(args.val_csv, args.val_dir,
                                      eval_transformation)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2 * args.workers,  # Needs images twice as fast
        pin_memory=True,
        drop_last=False)  # keep the final incomplete validation batch

    args.class_to_idx = train_dataset.class_to_idx
    return train_loader, val_loader