Example #1
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    print(args)

    # Get all patient ids
    patient_ids = get_patient_ids('../dataset/', 400)
    print(len(patient_ids))

    for domain in range(10):

        # Leave-one-domain-out: hold out the current patient as the test domain
        train_patient_ids = patient_ids[:]
        test_patient_ids = patient_ids[domain]
        train_patient_ids.remove(test_patient_ids)

        train_dataset = MalariaData('../dataset/',
                                    domain_list=train_patient_ids,
                                    transform=True)
        train_size = int(0.80 * len(train_dataset))
        test_size = len(train_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset, [train_size, test_size])

        train_loader = data_utils.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             **kwargs)
        val_loader = data_utils.DataLoader(val_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           **kwargs)
        print(args)

        # Train, val, test sets
        patient_ids = get_patient_ids('../dataset/', 400)
        print(len(patient_ids))

        # Semi-supervised split: patient 0 is the test domain, 'C59P20' is the
        # unsupervised training patient, and the remaining patients are supervised
        train_patient_ids = patient_ids[:]
        test_patient_ids = patient_ids[0]
        train_patient_ids_unsupervised = 'C59P20'

        train_patient_ids.remove(test_patient_ids)
        train_patient_ids.remove(train_patient_ids_unsupervised)

        print(test_patient_ids, train_patient_ids, train_patient_ids_unsupervised)

        train_dataset = MalariaData('../dataset/', domain_list=train_patient_ids, transform=True)
        train_size = int(0.80 * len(train_dataset))
        test_size = len(train_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, test_size])

        train_loader_supervised = data_utils.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                                        **kwargs)
        val_loader_supervised = data_utils.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

        test_loader = data_utils.DataLoader(
            MalariaData('../dataset/', domain_list=[test_patient_ids]),
            batch_size=1,
            shuffle=False,
            **kwargs)

        train_loader_unsupervised = data_utils.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                                          **kwargs)
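
Both examples call helpers that the snippets themselves do not define: get_patient_ids, MalariaData, the parsed args namespace, and a kwargs dict of DataLoader options. The block below is only a minimal, hypothetical stand-in for those names, with signatures and return shapes inferred from how they are used above; the real project code will differ (in particular, the actual MalariaData loads cell images rather than random tensors).

import argparse
import os

import torch
from torch.utils.data import Dataset

def get_patient_ids(data_path, min_cells=400):
    # Hypothetical stand-in: return the ids of patients (directory names such
    # as 'C59P20') that have at least `min_cells` cell images under data_path.
    return sorted(
        pid for pid in os.listdir(data_path)
        if len(os.listdir(os.path.join(data_path, pid))) >= min_cells
    )

class MalariaData(Dataset):
    # Hypothetical stand-in: yields (image, class label, domain label) triples,
    # matching the (x, y, d) unpacking in Example #3. A real implementation
    # would load cell crops from data_path/<patient id>/ instead.
    def __init__(self, data_path, domain_list=None, transform=False):
        self.domain_list = list(domain_list or [])
        self.samples_per_domain = 400

    def __len__(self):
        return len(self.domain_list) * self.samples_per_domain

    def __getitem__(self, idx):
        d = idx // self.samples_per_domain        # patient (domain) index
        x = torch.rand(3, 64, 64)                 # placeholder RGB cell crop
        y = torch.randint(0, 2, (1,)).squeeze()   # infected / uninfected label
        return x, y, torch.tensor(d)

# DataLoader options passed as **kwargs, and the args namespace used above.
kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
args = argparse.Namespace(batch_size=64)  # hypothetical default
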
Example #3
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data_utils

model.qy.activation = nn.LeakyReLU()

batch_size = 14

# seeds
torch.manual_seed(1)
np.random.seed(1)

patient_ids = get_patient_ids('../dataset/', 400)
print(len(patient_ids))

train_patient_ids = patient_ids[:]
test_patient_ids = patient_ids[0]
train_patient_ids.remove(test_patient_ids)

train_dataset = MalariaData('../dataset/', domain_list=train_patient_ids)
train_size = int(0.80 * len(train_dataset))
test_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    train_dataset, [train_size, test_size])

train_loader = data_utils.DataLoader(train_dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     **kwargs)

# Grab a single batch of images (batch_size = 14) from the training loader
for batch_idx, (x, y, d) in enumerate(train_loader):
    with torch.no_grad():
        x, y, d = x.cuda(), y.cuda(), d.cuda()
        break
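
The loop above only grabs the first batch and then breaks: x holds the cell images, y the infection labels, and d the domain (patient) labels, all moved to the GPU. A quick sanity check on that batch (assuming the .cuda() calls above succeeded):

# Inspect the single batch grabbed above.
print(x.shape)                 # (14, C, H, W) for batch_size = 14
print(y.unique(), d.unique())  # class labels and patient/domain indices present in the batch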