Beispiel #1
0
    augment_noise=(args.augment_noise == 1))
# Target-domain augmenters from the shared 2D augmentation pipeline; noise
# augmentation is enabled when --augment_noise == 1. The target test label
# transform is kept because the target test set is labeled (see tar_test).
tar_train_xtransform, _, tar_test_xtransform, tar_test_ytransform = get_augmenters_2d(
    augment_noise=(args.augment_noise == 1))
# load data
# Source domain: labeled volumes for both training and evaluation.
# NOTE(review): input_shape is passed positionally here but by keyword for the
# target datasets below — presumably the same parameter; confirm against the
# dataset signatures.
src_train = StronglyLabeledVolumeDataset(args.src_data_train,
                                         args.src_labels_train,
                                         input_shape,
                                         transform=src_train_xtransform,
                                         target_transform=src_train_ytransform)
src_test = StronglyLabeledVolumeDataset(args.src_data_test,
                                        args.src_labels_test,
                                        input_shape,
                                        transform=src_test_xtransform,
                                        target_transform=src_test_ytransform)
# Target domain: unlabeled data for training (domain-adaptation setting);
# labels are only available/used for the target test set.
tar_train = UnlabeledVolumeDataset(args.tar_data_train,
                                   input_shape=input_shape,
                                   transform=tar_train_xtransform)
tar_test = StronglyLabeledVolumeDataset(args.tar_data_test,
                                        args.tar_labels_test,
                                        input_shape=input_shape,
                                        transform=tar_test_xtransform,
                                        target_transform=tar_test_ytransform)
# Half the batch size per loader — presumably so a combined source+target
# batch stays within the configured budget; confirm against the training loop.
src_train_loader = DataLoader(src_train, batch_size=args.train_batch_size // 2)
src_test_loader = DataLoader(src_test, batch_size=args.test_batch_size // 2)
tar_train_loader = DataLoader(tar_train, batch_size=args.train_batch_size // 2)
tar_test_loader = DataLoader(tar_test, batch_size=args.test_batch_size // 2)
"""
    Build the network
"""
# Timestamped progress logging, consistent with the other script sections.
print('[%s] Building the network' % (datetime.datetime.now()))
net = UNet_CORAL(feature_maps=args.fm,
# Prepare output directories for the transformed source/target volumes.
# os.makedirs creates args.write_dir together with each subdirectory, and
# exist_ok=True makes the script rerunnable: the original os.mkdir calls
# raised FileExistsError when any of the directories already existed.
if args.write_dir is not None:
    os.makedirs(os.path.join(args.write_dir, 'src_transform'), exist_ok=True)
    os.makedirs(os.path.join(args.write_dir, 'tar_transform'), exist_ok=True)

"""
    Load the data
"""
# Single-channel 2D patches: (channels, height, width).
input_shape = (1, args.input_size[0], args.input_size[1])
print('[%s] Loading data' % (datetime.datetime.now()))
# augmenters
# Only the input (x) transforms are kept — all four datasets below are
# unlabeled in this setup, so the label transforms are discarded.
src_train_xtransform, _, src_test_xtransform, _ = get_augmenters_2d(augment_noise=(args.augment_noise == 1))
tar_train_xtransform, _, tar_test_xtransform, _ = get_augmenters_2d(augment_noise=(args.augment_noise == 1))
# load data
src_train = UnlabeledVolumeDataset(args.src_data_train, input_shape, transform=src_train_xtransform)
src_test = UnlabeledVolumeDataset(args.src_data_test, input_shape, transform=src_test_xtransform)
tar_train = UnlabeledVolumeDataset(args.tar_data_train, input_shape=input_shape, transform=tar_train_xtransform)
tar_test = UnlabeledVolumeDataset(args.tar_data_test, input_shape=input_shape, transform=tar_test_xtransform)
# Half the batch size per domain — presumably so a combined source+target
# batch matches the configured batch size; TODO confirm in the training loop.
src_train_loader = DataLoader(src_train, batch_size=args.train_batch_size//2)
src_test_loader = DataLoader(src_test, batch_size=args.test_batch_size//2)
tar_train_loader = DataLoader(tar_train, batch_size=args.train_batch_size//2)
tar_test_loader = DataLoader(tar_test, batch_size=args.test_batch_size//2)

"""
    Setup optimization for unsupervised training
"""
print('[%s] Setting up optimization for unsupervised training' % (datetime.datetime.now()))
# U-Net autoencoder configured from the CLI: feature maps, depth, optional
# group normalization, domain-loss weight (lambda_dom), and spatial size s
# (assumes square inputs — only input_size[0] is passed; TODO confirm).
net = UNetAE(feature_maps=args.fm, levels=args.levels, group_norm=(args.group_norm == 1), lambda_dom=args.lambda_dom, s=args.input_size[0])

"""
Beispiel #3
0
    augment_noise=(args.augment_noise == 1))
# load data
# Source domain: labeled volumes; preprocess='unit' presumably rescales
# intensities to the unit interval — confirm in the dataset implementation.
src_train = StronglyLabeledVolumeDataset(args.src_data_train,
                                         args.src_labels_train,
                                         input_shape,
                                         transform=src_train_xtransform,
                                         target_transform=src_train_ytransform,
                                         preprocess='unit')
src_test = StronglyLabeledVolumeDataset(args.src_data_test,
                                        args.src_labels_test,
                                        input_shape,
                                        transform=src_test_xtransform,
                                        target_transform=src_test_ytransform,
                                        preprocess='unit')
# Target domain: unlabeled for training (domain adaptation); labeled only for
# evaluation.
tar_train = UnlabeledVolumeDataset(args.tar_data_train,
                                   input_shape=input_shape,
                                   transform=tar_train_xtransform,
                                   preprocess='unit')
tar_test = StronglyLabeledVolumeDataset(args.tar_data_test,
                                        args.tar_labels_test,
                                        input_shape=input_shape,
                                        transform=tar_test_xtransform,
                                        target_transform=tar_test_ytransform,
                                        preprocess='unit')
# Half the batch size per loader — presumably so a combined source+target
# batch stays within the configured budget; confirm against the training loop.
src_train_loader = DataLoader(src_train, batch_size=args.train_batch_size // 2)
src_test_loader = DataLoader(src_test, batch_size=args.test_batch_size // 2)
tar_train_loader = DataLoader(tar_train, batch_size=args.train_batch_size // 2)
tar_test_loader = DataLoader(tar_test, batch_size=args.test_batch_size // 2)
"""
    Build the network
"""
# Timestamped progress logging, consistent with the other script sections.
print('[%s] Building the network' % (datetime.datetime.now()))
Beispiel #4
0
    len_epoch=args.len_epoch,
    type=df_src['type'],
    train=True)
# Labeled source evaluation set. The split-orientation/location pair selects
# how each volume is partitioned into train/test; train=False picks the
# held-out side. len_epoch caps the number of samples drawn per epoch.
test_src = StronglyLabeledVolumeDataset(
    df_src['raw'],
    df_src['labels'],
    split_orientation=df_src['split-orientation'],
    split_location=df_src['split-location'],
    input_shape=input_shape,
    len_epoch=args.len_epoch,
    type=df_src['type'],
    train=False)
# Unlabeled target-domain data: the train split feeds unsupervised adaptation,
# the test split is kept for evaluation. Both views are built over the same
# raw volumes (df_tar_ul), distinguished only by the train flag.
train_tar_ul = UnlabeledVolumeDataset(
    df_tar_ul['raw'],
    split_orientation=df_tar_ul['split-orientation'],
    split_location=df_tar_ul['split-location'],
    input_shape=input_shape,
    len_epoch=args.len_epoch,
    type=df_tar_ul['type'],
    train=True)
test_tar_ul = UnlabeledVolumeDataset(
    df_tar_ul['raw'],
    split_orientation=df_tar_ul['split-orientation'],
    split_location=df_tar_ul['split-location'],
    input_shape=input_shape,
    len_epoch=args.len_epoch,
    type=df_tar_ul['type'],
    train=False)
train_tar_l = StronglyLabeledVolumeDataset(
    df_tar_l['raw'],
    df_tar_l['labels'],
    split_orientation=df_tar_l['split-orientation'],
Beispiel #5
0
# load the network
print_frm('Loading network')
model_file = args.net
net = _load_net(model_file, device)

# load reference patch
print_frm('Loading data')
data_file = args.data_file
# Context manager closes the JSON file handle deterministically; the original
# json.load(open(data_file)) leaked the handle until garbage collection.
with open(data_file) as f:
    df = json.load(f)
n_domains = len(df['raw'])
# Single-channel 2D patches: (channels, height, width).
input_shape = (1, input_size[0], input_size[1])
# Reference dataset for the selected domain; train=False selects the test
# split of the volume, so dataset_ref[0] is a held-out reference patch.
dataset_ref = UnlabeledVolumeDataset(
    df['raw'][domain_id],
    split_orientation=df['split-orientation'][domain_id],
    split_location=df['split-location'][domain_id],
    input_shape=input_shape,
    type=df['types'][domain_id],
    train=False)
x_ref = dataset_ref[0]

# compute embedding reference sample
print_frm('Computing embedding reference samples')
# Prepend a batch dimension before moving the patch to the target device.
x_ref_t = tensor_to_device(torch.from_numpy(x_ref[np.newaxis, ...]),
                           device).float()
z_ref = net.encoder(x_ref_t)
z_ref = z_ref.detach().cpu().numpy()

# for all other domains, compute embeddings of randomly selected patches
print_frm('Computing embeddings remaining domains')
# Embedding buffer: n bottleneck vectors per domain.
z = np.zeros((n_domains, n, net.bottleneck_dim))
Beispiel #6
0
# load reference patch
print_frm('Loading data')
data_file = args.data_file
# Context manager closes the JSON file handle deterministically; the original
# json.load(open(data_file)) leaked the handle until garbage collection.
with open(data_file) as f:
    df = json.load(f)
n_domains = len(df['raw'])
# Single-channel 2D patches: (channels, height, width).
input_shape = (1, input_size[0], input_size[1])

# datasets
# One unlabeled test-split dataset per domain; len_epoch=n caps each dataset
# at the number of patches that will be sampled from it below.
dss = []
for d in range(n_domains):
    print_frm('Loading %s' % df['raw'][d])
    dss.append(
        UnlabeledVolumeDataset(df['raw'][d],
                               split_orientation=df['split-orientation'][d],
                               split_location=df['split-location'][d],
                               input_shape=input_shape,
                               type=df['types'][d],
                               train=False,
                               len_epoch=n))

# for all other domains, compute embeddings of randomly selected patches
print_frm('Computing embeddings remaining domains')
# Pairwise source/target distance matrix between domains, filled by the
# loop that follows.
sample_dists = np.zeros((n_domains, n_domains))
for d_src in range(n_domains):
    print_frm('Processing src domain %d/%d' % (d_src, n_domains))
    test_src = dss[d_src]
    dl = DataLoader(test_src, batch_size=args.batch_size)
    for d_tar in range(n_domains):
        print_frm('  Processing tar domain %d/%d' % (d_tar, n_domains))
        if d_src != d_tar:
            test_tar = dss[d_tar]