Пример #1
0
import random

# torch is required below: torch.mean in the similarity loss and torch.load for
# the pretrained weights. It was previously used without being imported, which
# raised NameError at module level.
import torch

import inverseConsistentNet
import networks
import network_wrappers
import data
import describe

BATCH_SIZE = 32
SCALE = 1  # 1 IS QUARTER RES, 2 IS HALF RES, 4 IS FULL RES
# Shape passed to assignIdentityMap below: [BATCH_SIZE, 1, 40, 96, 96] at SCALE = 1.
# NOTE(review): presumably (batch, channel, depth, height, width) — confirm.
input_shape = [BATCH_SIZE, 1, 40 * SCALE, 96 * SCALE, 96 * SCALE]

# NOTE(review): not referenced in the visible part of this script; presumably
# consumed further down (e.g. for multi-GPU training) — confirm before removing.
GPUS = 4


# Two 3-D networks, each wrapped by FunctionFromVectorField.
# NOTE(review): the wrapper presumably turns a predicted vector field into a
# warp function — confirm against network_wrappers.
phi = network_wrappers.FunctionFromVectorField(
    networks.tallUNet(unet=networks.UNet2ChunkyMiddle, dimension=3)
)
psi = network_wrappers.FunctionFromVectorField(networks.tallUNet2(dimension=3))

# Compose phi and psi with DoubleNet and wrap them in the inverse-consistent model.
pretrained_lowres_net = inverseConsistentNet.InverseConsistentNet(
    network_wrappers.DoubleNet(phi, psi),
    # Similarity term: mean of the squared element-wise difference of x and y.
    lambda x, y: torch.mean((x - y) ** 2),
    # NOTE(review): presumably the inverse-consistency weight (lambda) — confirm
    # against InverseConsistentNet's positional signature.
    100,
)

network_wrappers.assignIdentityMap(pretrained_lowres_net, input_shape)


# NOTE(review): the batch size is adjusted to 12 here, which differs from
# BATCH_SIZE (32) used for the identity map above; presumably intentional for
# loading/evaluating the checkpoint — confirm.
network_wrappers.adjust_batch_size(pretrained_lowres_net, 12)
trained_weights = torch.load(
    "results/dd_l400_continue_rescalegrad2/knee_aligner_resi_net1800"
import parent

from mermaidlite import compute_warped_image_multiNC, identity_map_multiN
import torch

import inverseConsistentNet
import networks
import data

SCALE = 1  # 1 IS QUARTER RES, 2 IS HALF RES, 4 IS FULL RES
BATCH_SIZE = 32

# Working tensor shape: batch size and a single channel, followed by three
# extents multiplied by SCALE ([32, 1, 40, 96, 96] at SCALE = 1).
# NOTE(review): the trailing dims are presumably (depth, height, width) — confirm.
working_shape = [BATCH_SIZE, 1, *(extent * SCALE for extent in (40, 96, 96))]

# Single 3-D tallUNet inside the inverse-consistent model, with random_sampling
# disabled. Its parameters are then restored from the saved state dict.
net = inverseConsistentNet.InverseConsistentNet(
    networks.tallUNet(dimension=3),
    input_shape=working_shape,
    lmbda=166,
    random_sampling=False,
)
net.load_state_dict(
    torch.load("network_weights/lowres_knee_network")
)

# The dataset loader returns two collections; keep both.
knees, medknees = data.get_knees_dataset()