Example #1
0
    def test_2d_registration_train(self):
        """Train a small 2-D registration network and check its final losses.

        Builds an ``InverseConsistentNet`` around a 2-D ``tallUNet2`` vector
        field network, trains it with ``train.train2d`` for 50 epochs on the
        solid (``hollow=False``) triangle dataset, and then asserts on the
        last five rows of the returned training log:

        * the mean of column 1 (image similarity) must be below 0.1;
        * the mean of ``exp(column 3 - 0.1)`` (fold measure) must be below 2.

        The network is moved to the GPU with ``net.cuda()``, so a CUDA device
        is required.
        """
        import os
        import random

        import numpy as np
        import torch

        import icon_registration.data as data
        import icon_registration.inverseConsistentNet as inverseConsistentNet
        import icon_registration.network_wrappers as network_wrappers
        import icon_registration.networks as networks
        import icon_registration.train as train

        # Seed every random number generator used here with the same value.
        seed = 1
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)

        batch_size = 128
        triangle_options = dict(data_size=50, hollow=False, batch_size=batch_size)

        d1, d2 = data.get_dataset_triangles("train", **triangle_options)
        # The test split is never read afterwards in this test; the call is
        # kept so the sequence of operations after seeding stays unchanged.
        d1_t, d2_t = data.get_dataset_triangles("test", **triangle_options)

        lmbda = 2048

        print("=" * 50)

        def first_channel_mse(x, y):
            # Our image similarity metric. The last channel of x and y is
            # whether the value is interpolated or extrapolated, which is
            # used by some metrics but not this one.
            return torch.mean((x[:, :1] - y[:, :1])**2)

        registration_module = network_wrappers.FunctionFromVectorField(
            networks.tallUNet2(dimension=2))
        net = inverseConsistentNet.InverseConsistentNet(
            registration_module,
            first_channel_mse,
            lmbda,
        )

        # The identity map is sized from element 0 of the first training batch.
        first_batch = next(iter(d1))
        input_shape = first_batch[0].size()
        network_wrappers.assignIdentityMap(net, input_shape)
        net.cuda()
        optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
        net.train()

        training_log = train.train2d(net, optimizer, d1, d2, epochs=50)
        y = np.array(training_log)
        final_rows = y[-5:]

        # Test that image similarity is good enough
        self.assertLess(np.mean(final_rows[:, 1]), 0.1)

        # Test that folds are rare enough
        self.assertLess(np.mean(np.exp(final_rows[:, 3] - 0.1)), 2)
        print(y)
# Not referenced anywhere in the visible part of this script.
GPUS = 4

# Checkpoint restored into the low-resolution network below.
# Alternatives that were tried earlier:
#   "../results/dd_knee_l400_continue_smallbatch2/knee_aligner_resi_net9300"
#   "../results/double_deformable_knee3/knee_aligner_resi_net22200"
_LOWRES_CHECKPOINT = "results/dd_l400_continue_rescalegrad2/knee_aligner_resi_net25500"

# Two 3-D vector-field registration modules: phi wraps a tallUNet built with
# UNet2ChunkyMiddle, psi wraps a tallUNet2.
phi = network_wrappers.FunctionFromVectorField(
    networks.tallUNet(dimension=3, unet=networks.UNet2ChunkyMiddle)
)
psi = network_wrappers.FunctionFromVectorField(
    networks.tallUNet2(dimension=3)
)

# Both modules are combined in a DoubleNet; similarity is the mean squared
# difference of the two images. The third positional argument (100) is
# presumably the regularisation weight -- confirm against InverseConsistentNet.
_lowres_registration = network_wrappers.DoubleNet(phi, psi)
pretrained_lowres_net = inverseConsistentNet.InverseConsistentNet(
    _lowres_registration,
    lambda x, y: torch.mean((x - y) ** 2),
    100,
)

# input_shape is defined outside this excerpt.
network_wrappers.assignIdentityMap(pretrained_lowres_net, input_shape)
network_wrappers.adjust_batch_size(pretrained_lowres_net, 12)

# torch.load unpickles the file, so only point this at trusted checkpoints.
trained_weights = torch.load(_LOWRES_CHECKPOINT)
pretrained_lowres_net.load_state_dict(trained_weights)

hires_net = inverseConsistentNet.InverseConsistentNet(
    network_wrappers.DownsampleNet(pretrained_lowres_net.regis_net, dimension=3),
    lambda x, y: torch.mean((x - y) ** 2),
    800,
# Element 0 of the first batch from each loader, moved to the GPU.
# d1 and d2 are defined outside this excerpt.
image_A, image_B = (x[0].cuda() for x in next(zip(d1, d2)))

# Registration model: a DoubleNet of two 2-D tallUNet2 vector-field modules,
# with mean squared difference as the similarity measure. The third
# positional argument (700) is presumably the regularisation weight --
# confirm against InverseConsistentNet.
net = inverseConsistentNet.InverseConsistentNet(
    network_wrappers.DoubleNet(
        network_wrappers.FunctionFromVectorField(
            networks.tallUNet2(dimension=2)),
        network_wrappers.FunctionFromVectorField(
            networks.tallUNet2(dimension=2)),
    ),
    lambda x, y: torch.mean((x - y)**2),
    700,
)

# The identity map is sized from element 0 of the first training batch.
input_shape = next(iter(d1))[0].size()
network_wrappers.assignIdentityMap(net, input_shape)
# Move the model to the GPU before the optimizer is built from its parameters.
net.cuda()

import icon_registration.train as train

optim = torch.optim.Adam(net.parameters(), lr=0.0001)
# The model is already on the GPU (see net.cuda() above), so only the
# training-mode switch is needed here.
net.train()

# Train in 240 rounds of 50 epochs, rewriting the loss plot after each round.
xs = []
for _ in range(240):
    y = np.array(train.train2d(net, optim, d1, d2, epochs=50))
    xs.append(y)
    # Full training log so far: one row per logged step, all rounds stacked.
    x = np.concatenate(xs)
    plt.title("Loss curve for " + type(net.regis_net).__name__)
    # Only the first three columns of the log are plotted.
    plt.plot(x[:, :3])
    plt.savefig(footsteps.output_dir + "loss.png")
    # Clear the figure so the next round does not draw on top of this
    # round's curves; without this every saved image overlays all earlier,
    # shorter curves and the figure keeps accumulating line artists.
    plt.clf()
Example #4
0
# data_size is forwarded to get_dataset_triangles for both splits.
data_size = 50
# Hollow-triangle datasets; batch_size is defined outside this excerpt.
d1, d2 = data.get_dataset_triangles(
    "train", data_size=data_size, hollow=True, batch_size=batch_size
)
# The test split is not used in the visible remainder of this script.
d1_t, d2_t = data.get_dataset_triangles(
    "test", data_size=data_size, hollow=True, batch_size=batch_size
)

# Element 0 of the first batch from each training loader, moved to the GPU.
image_A, image_B = (x[0].cuda() for x in next(zip(d1, d2)))

# Registration model: a ConvolutionalMatrixNet wrapped in FunctionFromMatrix,
# with mean squared difference as the similarity measure. The third
# positional argument (100) is presumably the regularisation weight --
# confirm against InverseConsistentNet.
net = inverseConsistentNet.InverseConsistentNet(
    network_wrappers.FunctionFromMatrix(networks.ConvolutionalMatrixNet()),
    lambda x, y: torch.mean((x - y) ** 2),
    100,
)
# The identity map is sized from the shape of the sampled image.
network_wrappers.assignIdentityMap(net, image_A.shape)
net.cuda()

import icon_registration.train as train

optim = torch.optim.Adam(net.parameters(), lr=0.00001)
# The model was already moved to the GPU above; this also switches it to
# training mode.
net.train().cuda()


# Train in 240 rounds of 50 epochs, plotting the accumulated log each round.
# NOTE(review): every round calls plt.plot on the current figure without
# clearing it, and nothing in the visible lines saves or shows the figure --
# the loop body may be cut off in this excerpt. Confirm against the full
# script before relying on this plot.
xs = []
for _ in range(240):
    y = np.array(train.train2d(net, optim, d1, d2, epochs=50))
    xs.append(y)
    # Full training log so far, all rounds stacked.
    x = np.concatenate(xs)
    plt.title("Loss curve for " + type(net.regis_net).__name__)
    # Only the first three columns of the log are plotted.
    plt.plot(x[:, :3])
Example #5
0
# Input geometry used when assigning the identity map below.
BATCH_SIZE = 4
SCALE = 2  # 1 is quarter res, 2 is half res, 4 is full res
input_shape = [BATCH_SIZE, 1] + [extent * SCALE for extent in (40, 96, 96)]

# Low-resolution two-stage module; phi is defined outside this excerpt.
psi = network_wrappers.FunctionFromVectorField(
    networks.tallUNet2(dimension=3)
)
pretrained_lowres_net = network_wrappers.DoubleNet(phi, psi)

# High-resolution model: the low-resolution module wrapped in a DownsampleNet,
# followed by a new 3-D tallUNet2 vector-field module. The two stages are
# built in this order, first the wrapped module and then the new one.
_hires_first_stage = network_wrappers.DownsampleNet(
    pretrained_lowres_net, dimension=3
)
_hires_second_stage = network_wrappers.FunctionFromVectorField(
    networks.tallUNet2(dimension=3)
)
hires_net = inverseConsistentNet.InverseConsistentNet(
    network_wrappers.DoubleNet(_hires_first_stage, _hires_second_stage),
    inverseConsistentNet.ssd_only_interpolated,
    3600,
)
network_wrappers.assignIdentityMap(hires_net, input_shape)

# Restore the saved high-resolution weights. torch.load unpickles the file,
# so only point this at trusted checkpoints.
trained_weights = torch.load("results/hires_smart_6/knee_aligner_resi_net74700")
hires_net.load_state_dict(trained_weights)

# Fourth-stage model: the restored high-resolution registration module
# followed by one more new 3-D tallUNet2 vector-field module.
_fourth_stage = network_wrappers.FunctionFromVectorField(
    networks.tallUNet2(dimension=3)
)
fourth_net = inverseConsistentNet.InverseConsistentNet(
    network_wrappers.DoubleNet(hires_net.regis_net, _fourth_stage),
    inverseConsistentNet.ssd_only_interpolated,
    3600,
)

# Turn off gradients for every parameter under regis_net.netPhi.
for p in fourth_net.regis_net.netPhi.parameters():
    p.requires_grad = False