Code example #1
def train(model_name, gpu_id):

    model_dir = '../models/' + model_name
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    with tf.device(gpu):
        model = networks.unet(vol_size, nf_enc, nf_dec)
        model.compile(optimizer=Adam(lr=lr), loss=[
                      losses.cc3D(), losses.gradientLoss('l2')], loss_weights=[1.0, reg_param])
        # model.load_weights('../models/udrnet2/udrnet1_1/120000.h5')

    train_example_gen = datagenerators.example_gen(train_vol_names)
    zero_flow = np.zeros((1, vol_size[0], vol_size[1], vol_size[2], 3))

    for step in range(0, n_iterations):

        X = next(train_example_gen)[0]
        train_loss = model.train_on_batch(
            [X, atlas_vol], [atlas_vol, zero_flow])

        if not isinstance(train_loss, list):
            train_loss = [train_loss]

        printLoss(step, 1, train_loss)

        if step % model_save_iter == 0:
            model.save(model_dir + '/' + str(step) + '.h5')
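This snippet relies on module-level globals (vol_size, nf_enc, nf_dec, lr, reg_param, n_iterations, model_save_iter, train_vol_names, atlas_vol) and on a printLoss helper defined elsewhere in the project. Below is a minimal sketch of what that setup might look like; all values, paths, and the printLoss stub are illustrative assumptions, not taken from the original repository.

import os
import glob
import numpy as np
import tensorflow as tf
from keras.optimizers import Adam
from keras.backend.tensorflow_backend import set_session

import networks
import losses
import datagenerators

# Illustrative hyperparameters; the real values live in the project's driver script.
vol_size = (160, 192, 224)                 # common voxelmorph volume shape
nf_enc = [16, 32, 32, 32]                  # encoder filters
nf_dec = [32, 32, 32, 32, 32, 16, 16, 3]   # decoder filters (vm2-style)
lr = 1e-4
reg_param = 1.0
n_iterations = 150000
model_save_iter = 5000

# Hypothetical data locations.
train_vol_names = glob.glob('../data/train/*.npz')
atlas_vol = np.load('../data/atlas_norm.npz')['vol'][np.newaxis, ..., np.newaxis]

def printLoss(step, train_or_val, loss_list):
    # Hypothetical stand-in for the project's logging helper.
    print(step, train_or_val, loss_list, flush=True)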
Code example #2
def train(model, gpu_id, lr, n_iterations, reg_param, model_save_iter,
          load_iter):

    model_dir = '/home/ys895/MAS3_Models'
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # UNET filters
    nf_enc = [16, 32, 32, 32]
    if model == 'vm1':
        nf_dec = [32, 32, 32, 32, 8, 8, 3]
    else:
        nf_dec = [32, 32, 32, 32, 32, 16, 16, 3]

    with tf.device(gpu):
        model = networks.unet(vol_size, nf_enc, nf_dec)
        if load_iter != 0:
            model.load_weights('/home/ys895/MAS3_Models/' + str(load_iter) +
                               '.h5')

        model.compile(optimizer=Adam(lr=lr),
                      loss=[losses.cc3D(),
                            losses.gradientLoss('l2')],
                      loss_weights=[1.0, reg_param])
        # model.load_weights('../models/udrnet2/udrnet1_1/120000.h5')

    # data generator; each yielded volume carries an extra batch dimension
    train_example_gen = datagenerators.example_gen(train_vol_names)
    zero_flow = np.zeros((1, vol_size[0], vol_size[1], vol_size[2], 3))

    # Training loop: feed the data into the model configured above
    for step in range(1, n_iterations + 1):
        # randomly choose one atlas from atlas_list (list_num is its length)
        rand_num = random.randint(0, list_num - 1)
        atlas_vol = atlas_list[rand_num]

        # training inputs: X (moving volume), atlas_vol (fixed atlas), zero_flow
        X = next(train_example_gen)[0]
        train_loss = model.train_on_batch([atlas_vol, X], [X, zero_flow])

        if not isinstance(train_loss, list):
            train_loss = [train_loss]

        printLoss(step, 1, train_loss)

        if step % model_save_iter == 0:
            model.save(model_dir + '/' + str(load_iter + step) + '.h5')
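Compared with example #1, this variant additionally assumes a global pool of atlas volumes (atlas_list) and its length (list_num). A plausible setup sketch, with a placeholder path:

import glob
import numpy as np

# Hypothetical atlas pool; the directory is a placeholder.
atlas_files = sorted(glob.glob('/home/ys895/MAS3_atlases/*.npz'))
atlas_list = [np.load(f)['vol'][np.newaxis, ..., np.newaxis] for f in atlas_files]
list_num = len(atlas_list)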
Code example #3
def test(model_name,
         iter_num,
         gpu_id,
         vol_size=(160, 192, 224),
         nf_enc=[16, 32, 32, 32],
         nf_dec=[32, 32, 32, 32]):
    """
    test

    nf_enc and nf_dec
    #nf_dec = [32,32,32,32,32,16,16,3]
    # This needs to be changed. Ideally, we could just call load_model, and we wont have to
    # specify the # of channels here, but the load_model is not working with the custom loss...
    """

    gpu = '/gpu:' + str(gpu_id)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # load weights of model
    with tf.device(gpu):
        full_model, train_model = networks.autoencoder(vol_size, nf_enc,
                                                       nf_dec)
        full_model.load_weights('../models/' + model_name + '/' +
                                str(iter_num) + '.h5')

    val_example_gen = datagenerators.example_gen(val_vol_names)

    # evaluate. Note: we use predict and our own print function, as this has enabled
    # faster development and debugging, but one could also use Keras' built-in evaluation.
    total_loss = 0
    for step in range(1):

        # get data
        X = next(val_example_gen)[0]

        # get output
        output, enc = full_model.predict([X])

        # output and X are numpy arrays here, so compute the loss with numpy
        # rather than building an unevaluated TF graph op
        loss = np.mean(np.square(output - X))

        # print the loss.
        print(step, 0, loss)
        total_loss += loss

        slices(output[0])

    print(total_loss)
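An illustrative call, assuming val_vol_names and the slices helper are defined at module level; the model name and iteration number are placeholders:

if __name__ == '__main__':
    test('autoencoder_run1', 120000, gpu_id=0)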
Code example #4
def test(gpu, ref_dir, mov_dir, model, init_model_file):
    """
    model training function
    :param gpu: integer specifying the gpu to use
    :param atlas_file: atlas filename. So far we support npz file with a 'vol' variable
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param init_model_file: the model directory to load from
    """

    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    device = "cuda"

    # Prepare the vm1 or vm2 model and send to device
    nf_enc = [16, 32, 32, 32]
    if model == "vm1":
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == "vm2":
        nf_dec = [32, 32, 32, 32, 32, 16, 16]

    # Set up model
    vol_size = [2912, 2912]
    model = cvpr2018_net(vol_size, nf_enc, nf_dec)
    model.to(device)
    model.load_state_dict(
        torch.load(init_model_file, map_location=lambda storage, loc: storage))

    # set up
    ref_vol_names = glob.glob(os.path.join(ref_dir, '*.npy'))
    mov_vol_names = glob.glob(os.path.join(mov_dir, '*.npy'))
    nums = len(ref_vol_names)

    for k in range(0, nums):
        refs, movs = datagenerators.example_gen(ref_vol_names,
                                                mov_vol_names,
                                                batch_size=1)
        input_fixed = torch.from_numpy(refs).to(device).float()
        input_fixed = input_fixed.permute(0, 3, 1, 2)
        input_moving = torch.from_numpy(movs).to(device).float()
        input_moving = input_moving.permute(0, 3, 1, 2)

        # Use this to warp segments
        # trf = SpatialTransformer(input_fixed.shape[2:], mode='nearest')
        # trf.to(device)
        warp, flow = model(input_moving, input_fixed)
        flow_save = sitk.GetImageFromArray(flow.cpu().detach().numpy())
        # sitk.WriteImage(flow_save,'D:\peizhunsd\data\\flow_img\\' + str(k) + '.nii')

        # visualization of the displacement vector field
        # addimage(input_fixed, input_moving, warp, k)   # visualize the result
        dice_score = metrics.dice_score(warp, input_fixed)
        print('Dice similarity:', dice_score)
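A minimal CLI wrapper for this function, assuming it lives in a test.py; the argument names mirror the signature and the defaults are illustrative:

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default='0')
    parser.add_argument('--ref_dir', default='./data/ref')
    parser.add_argument('--mov_dir', default='./data/mov')
    parser.add_argument('--model', default='vm2', choices=['vm1', 'vm2'])
    parser.add_argument('--init_model_file', required=True)
    args = parser.parse_args()
    test(**vars(args))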
Code example #5
File: train.py Project: ymcidence/voxelmorph
def train(model, save_name, gpu_id, lr, n_iterations, reg_param, model_save_iter):

    model_dir = '../models/' + save_name
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))


    # UNET filters
    nf_enc = [16, 32, 32, 32]
    if model == 'vm1':
        nf_dec = [32, 32, 32, 32, 8, 8, 3]
    else:
        nf_dec = [32, 32, 32, 32, 32, 16, 16, 3]

    with tf.device(gpu):
        model = networks.unet(vol_size, nf_enc, nf_dec)
        model.compile(optimizer=Adam(lr=lr), loss=[
                      losses.cc3D(), losses.gradientLoss('l2')], loss_weights=[1.0, reg_param])
        # model.load_weights('../models/udrnet2/udrnet1_1/120000.h5')

    train_example_gen = datagenerators.example_gen(train_vol_names)
    zero_flow = np.zeros((1, vol_size[0], vol_size[1], vol_size[2], 3))

    for step in range(0, n_iterations):

        X = next(train_example_gen)[0]
        train_loss = model.train_on_batch(
            [X, atlas_vol], [atlas_vol, zero_flow])

        if not isinstance(train_loss, list):
            train_loss = [train_loss]

        printLoss(step, 1, train_loss)

        if step % model_save_iter == 0:
            model.save(model_dir + '/' + str(step) + '.h5')
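An illustrative invocation; every argument is a placeholder, and the globals vol_size, train_vol_names, and atlas_vol are assumed to be set up roughly as sketched under example #1:

train(model='vm2', save_name='vm2_run1', gpu_id=0, lr=1e-4,
      n_iterations=150000, reg_param=1.0, model_save_iter=5000)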
Code example #6
def train(data_dir, fixed_image, model_dir, device, lr, nb_epochs, AAN_param,
          steps_per_epoch, batch_size, load_model_file, initial_epoch,
          DLR_model):

    # prepare data files
    # inside the folder are npz files with the 'vol' and 'label'.
    train_vol_names = glob.glob(os.path.join(data_dir, '*.npz'))
    random.shuffle(train_vol_names)  # shuffle volume list
    assert len(train_vol_names) > 0, "Could not find any training data"
    vol_size = [144, 192, 160]

    # load atlas from provided files, if atlas-based registration
    if fixed_image != './':
        fixed_vol = np.load(fixed_image)['vol'][np.newaxis, ..., np.newaxis]

    def FAIM_loss(y_true, y_pred):
        return losses.Grad('l2').loss(
            y_true, y_pred) + 1e-5 * losses.NJ_loss(y_true, y_pred)

    assert DLR_model in [
        'VM', 'FAIM'
    ], 'DLR_model should be one of VM or FAIM, found %s' % DLR_model
    if DLR_model == 'FAIM':
        reg_loss = FAIM_loss
    else:
        reg_loss = losses.Grad('l2').loss

    # prepare model folder
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # device handling
    if 'gpu' in device:
        if '0' in device:
            device = '/gpu:0'
        if '1' in device:
            device = '/gpu:1'
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        set_session(tf.Session(config=config))
    else:
        device = '/cpu:0'

    # prepare the model
    with tf.device(device):
        model = networks.AAN(vol_size, DLR_model)

        # load initial weights
        if load_model_file != './':
            print('loading', load_model_file)
            model.load_weights(load_model_file)

        # save first iteration
        model.save(os.path.join(model_dir, '%02d.h5' % initial_epoch))

    # data generator
    train_example_gen = datagenerators.example_gen(train_vol_names,
                                                   batch_size=batch_size,
                                                   return_boundary=True)
    if fixed_image != './':
        fixed_vol_bs = np.repeat(fixed_vol, batch_size, axis=0)
        data_gen = datagenerators.gen_atlas(train_example_gen,
                                            fixed_vol_bs,
                                            batch_size=batch_size)
    else:
        data_gen = datagenerators.gen_s2s(train_example_gen,
                                          batch_size=batch_size)

    # prepare callbacks
    save_file_name = os.path.join(model_dir, '{epoch:02d}.h5')

    # fit generator
    with tf.device(device):

        save_callback = ModelCheckpoint(save_file_name)

        # compile
        model.compile(
            optimizer=Adam(lr=lr),
            loss=['mse', reg_loss,
                  losses.Grad('l1').loss_with_boundary],
            loss_weights=[1.0, 0.01, AAN_param])

        # fit
        model.fit_generator(data_gen,
                            initial_epoch=initial_epoch,
                            epochs=nb_epochs,
                            callbacks=[save_callback],
                            steps_per_epoch=steps_per_epoch,
                            verbose=1)
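An illustrative invocation for the atlas-based mode; all paths and hyperparameters are placeholders. Note that this code uses './' as the "unset" sentinel for fixed_image and load_model_file:

train(data_dir='./data/train',
      fixed_image='./data/atlas.npz',   # passing './' instead selects subject-to-subject mode
      model_dir='./models/aan',
      device='gpu0', lr=1e-4, nb_epochs=100, AAN_param=0.5,
      steps_per_epoch=100, batch_size=1,
      load_model_file='./',             # './' means train from scratch
      initial_epoch=0, DLR_model='VM')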
Code example #7
def train(model,
          pretrained_path,
          model_name,
          gpu_id,
          lr,
          n_iterations,
          use_mi,
          gamma,
          num_bins,
          patch_size,
          max_clip,
          reg_param,
          model_save_iter,
          local_mi,
          sigma_ratio,
          batch_size=1):
    """
    model training function
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param model_name: the model name; checkpoints are saved under ../models/<model_name>
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param n_iterations: number of training iterations
    :param reg_param: the smoothness/reconstruction tradeoff parameter (lambda in CVPR paper)
    :param model_save_iter: frequency with which to save models
    :param batch_size: Optional, default of 1. can be larger, depends on GPU memory and volume size
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    restrict_GPU_tf(str(gpu_id))
    restrict_GPU_keras(str(gpu_id))

    train_labels = sio.loadmat('../data/labels.mat')['labels'][0]
    n_labels = train_labels.shape[0]

    normalized_atlas_vol = atlas_vol / np.max(atlas_vol) * max_clip

    atlas_seg = datagenerators.split_seg_into_channels(seg, train_labels)
    atlas_seg = datagenerators.downsample(atlas_seg)

    model_dir = "../models/" + model_name
    # prepare model folder
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # GPU handling
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # UNET filters for voxelmorph-1 and voxelmorph-2,
    # these are architectures presented in CVPR 2018
    nf_enc = [16, 32, 32, 32]
    if model == 'vm1':
        nf_dec = [32, 32, 32, 32, 8, 8]
    else:
        nf_dec = [32, 32, 32, 32, 32, 16, 16]

    # prepare the model
    # in the CVPR layout, the model takes in [image_1, image_2] and outputs [warped_image_1, flow]
    # in the experiments, we use image_2 as atlas

    bin_centers = np.linspace(0, max_clip, num_bins * 2 + 1)[1::2]
    loss_function = losses.mutualInformation(bin_centers,
                                             max_clip=max_clip,
                                             local_mi=local_mi,
                                             patch_size=patch_size,
                                             sigma_ratio=sigma_ratio)

    model = networks.cvpr2018_net(vol_size,
                                  nf_enc,
                                  nf_dec,
                                  use_seg=True,
                                  n_seg=len(train_labels))
    model.compile(optimizer=Adam(lr=lr),
                  loss=[
                      loss_function,
                      losses.gradientLoss('l2'),
                      sparse_categorical_crossentropy
                  ],
                  loss_weights=[1 if use_mi else 0, reg_param, gamma])

    # if you'd like to initialize the model with pretrained weights, do it here:
    if pretrained_path is not None and pretrained_path != '':
        model.load_weights(pretrained_path)

    # prepare data for training
    train_example_gen = datagenerators.example_gen(train_vol_names,
                                                   return_segs=True,
                                                   seg_dir=train_seg_dir)
    zero_flow = np.zeros([batch_size, *vol_size, 3])

    # train. Note: we use train_on_batch and design our own print function as this has enabled
    # faster development and debugging, but one could also use fit_generator and Keras callbacks.
    for step in range(0, n_iterations):

        # get data
        X = next(train_example_gen)
        X_seg = X[1]

        X_seg = datagenerators.split_seg_into_channels(X_seg, train_labels)
        X_seg = datagenerators.downsample(X_seg)

        # train
        train_loss = model.train_on_batch(
            [X[0], normalized_atlas_vol, X_seg],
            [normalized_atlas_vol, zero_flow, atlas_seg])
        if not isinstance(train_loss, list):
            train_loss = [train_loss]

        # print the loss.
        print_loss(step, 1, train_loss)

        # save model
        if step % model_save_iter == 0:
            model.save(os.path.join(model_dir, str(step) + '.h5'))
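A quick check of the bin-center construction above: taking the odd-indexed points of a (2 * num_bins + 1)-point grid yields the midpoints of num_bins equal bins over [0, max_clip]:

import numpy as np

num_bins, max_clip = 4, 1.0
edges = np.linspace(0, max_clip, num_bins + 1)               # [0.  0.25 0.5  0.75 1. ]
centers = np.linspace(0, max_clip, num_bins * 2 + 1)[1::2]   # [0.125 0.375 0.625 0.875]
assert np.allclose(centers, (edges[:-1] + edges[1:]) / 2)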
Code example #8
def train(data_dir,
          atlas_file,
          model_dir,
          model,
          gpu_id,
          lr,
          nb_epochs,
          prior_lambda,
          image_sigma,
          mean_lambda,
          steps_per_epoch,
          batch_size,
          load_model_file,
          bidir,
          atlas_wt,
          bias_mult,
          smooth_pen_layer,
          data_loss,
          reg_param,
          ncc_win,
          initial_epoch=0):
    """
    model training function
    :param data_dir: folder with npz files for each subject.
    :param atlas_file: atlas filename. So far we support npz file with a 'vol' variable
    :param model_dir: model folder to save to
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param nb_epochs: number of training iterations
    :param prior_lambda: the prior_lambda, the scalar in front of the smoothing laplacian, in MICCAI paper
    :param image_sigma: the image sigma in MICCAI paper
    :param steps_per_epoch: number of training batches per epoch (models are saved once per epoch)
    :param batch_size: Optional, default of 1. can be larger, depends on GPU memory and volume size
    :param load_model_file: optional h5 model file to initialize with
    :param bidir: logical whether to use bidirectional cost function
    """
    
     
    # prepare data files
    # we have data arranged in train/validate/test folders
    # inside each folder is a /vols/ and a /asegs/ folder with the volumes
    # and segmentations. All of our papers use npz formated data.
    train_vol_names = glob.glob(data_dir)
    train_vol_names = [f for f in train_vol_names if 'ADNI' not in f]
    random.shuffle(train_vol_names)  # shuffle volume list
    assert len(train_vol_names) > 0, "Could not find any training data"

    # data generator
    train_example_gen = datagenerators.example_gen(train_vol_names, batch_size=batch_size)

    # prepare the initial weights for the atlas "layer"
    if atlas_file is None or atlas_file == "":
        nb_atl_creation = 100
        print('creating "atlas" by averaging %d subjects' % nb_atl_creation)
        x_avg = 0
        for _ in range(nb_atl_creation):
            x_avg += next(train_example_gen)[0][0,...,0]
        x_avg /= nb_atl_creation

        x_avg = x_avg[np.newaxis,...,np.newaxis]
        atlas_vol = x_avg
    else:
        atlas_vol = np.load(atlas_file)['vol'][np.newaxis, ..., np.newaxis]
    vol_size = atlas_vol.shape[1:-1]

    # Diffeomorphic network architecture used in MICCAI 2018 paper
    nf_enc = [16,32,32,32]
    nf_dec = [32,32,32,32,16,3] 
    if model == 'm1':
        pass
    elif model == 'm1double':
        nf_enc = [f*2 for f in nf_enc]
        nf_dec = [f*2 for f in nf_dec]

    # prepare model folder
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)


    assert data_loss in ['mse', 'cc', 'ncc'], 'Loss should be one of mse, cc, or ncc, found %s' % data_loss
    if data_loss in ['ncc', 'cc']:
        data_loss = losses.NCC(win=[ncc_win]*3).loss      
    else:
        data_loss = lambda y_t, y_p: K.mean(K.square(y_t-y_p))

    # gpu handling
    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # prepare the model
    with tf.device(gpu):
        # the MICCAI2018 model takes in [image_1, image_2] and outputs [warped_image_1, velocity_stats]
        # in these experiments, we use image_2 as atlas
        model = networks.img_atlas_diff_model(vol_size, nf_enc, nf_dec, 
                                            atl_mult=1, bidir=bidir,
                                            smooth_pen_layer=smooth_pen_layer)



        # compile
        mean_layer_loss = lambda _, y_pred: mean_lambda * K.mean(K.square(y_pred))

        flow_vol_shape = model.outputs[-2].shape[1:-1]
        loss_class = losses.Miccai2018(image_sigma, prior_lambda, flow_vol_shape=flow_vol_shape)
        if bidir:
            model_losses = [data_loss,
                            lambda _,y_p: data_loss(model.get_layer('atlas').output, y_p),
                            mean_layer_loss,
                            losses.Grad('l2').loss]
            loss_weights = [atlas_wt, 1-atlas_wt, 1, reg_param]
        else:
            model_losses = [loss_class.recon_loss, loss_class.kl_loss, mean_layer_loss]
            loss_weights = [1, 1, 1]
        model.compile(optimizer=Adam(lr=lr), loss=model_losses, loss_weights=loss_weights)
    
        # set initial weights in model
        model.get_layer('atlas').set_weights([atlas_vol[0,...]])

        # load initial weights; note this overrides the img_param weights
        if load_model_file is not None and len(load_model_file) > 0:
            model.load_weights(load_model_file, by_name=True)



    # save first iteration
    model.save(os.path.join(model_dir, '%02d.h5' % initial_epoch))

    # atlas_generator specific to this model. Once we're convinced of this, move to datagenerators
    def atl_gen(gen):
        zero_flow = np.zeros([batch_size, *vol_size, len(vol_size)])
        zero_flow_half = np.zeros([batch_size] + [f//2 for f in vol_size] + [len(vol_size)])
        while 1:
            x2 = next(gen)[0]  # use the generator argument, not the global
            # TODO: note this is the opposite of train_miccai and it might be confusing.
            yield ([atlas_vol, x2], [x2, atlas_vol, zero_flow, zero_flow])

    atlas_gen = atl_gen(train_example_gen)

    # prepare callbacks
    save_file_name = os.path.join(model_dir, '{epoch:02d}.h5')
    save_callback = ModelCheckpoint(save_file_name)

    # fit generator
    with tf.device(gpu):
        model.fit_generator(atlas_gen, 
                            initial_epoch=initial_epoch,
                            epochs=nb_epochs,
                            callbacks=[save_callback],
                            steps_per_epoch=steps_per_epoch,
                            verbose=1)
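An illustrative invocation; every path and hyperparameter is a placeholder, and the smooth_pen_layer value in particular is a hypothetical layer name whose real options are defined by networks.img_atlas_diff_model:

train(data_dir='./data/train/vols/*.npz',   # passed straight to glob.glob
      atlas_file=None,                      # None: build the atlas by averaging subjects
      model_dir='./models/atlas_diff', model='m1', gpu_id=0, lr=1e-4,
      nb_epochs=100, prior_lambda=25, image_sigma=0.02, mean_lambda=1.0,
      steps_per_epoch=100, batch_size=1, load_model_file=None,
      bidir=0, atlas_wt=0.5, bias_mult=1.0,
      smooth_pen_layer='diffflow',          # hypothetical layer name
      data_loss='ncc', reg_param=1.0, ncc_win=9)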
Code example #9
File: train.py Project: yiqian-wang/voxelmorph
def train(gpu, data_dir, atlas_file, lr, n_iter, data_loss, model, reg_param,
          batch_size, n_save_iter, model_dir):
    """
    model training function
    :param gpu: integer specifying the gpu to use
    :param data_dir: folder with npz files for each subject.
    :param atlas_file: atlas filename. So far we support npz file with a 'vol' variable
    :param lr: learning rate
    :param n_iter: number of training iterations
    :param data_loss: 'mse' or 'ncc'
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param reg_param: the smoothness/reconstruction tradeoff parameter (lambda in CVPR paper)
    :param batch_size: Optional, default of 1. can be larger, depends on GPU memory and volume size
    :param n_save_iter: Optional, default of 500. Determines how many iterations before saving a model version.
    :param model_dir: the model directory to save to
    """

    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    device = "cuda"

    # Load the atlas volume, with dims 160x192x224.
    atlas_vol = np.load(atlas_file)['vol'][np.newaxis, ..., np.newaxis]
    vol_size = atlas_vol.shape[1:-1]

    # Get all the names of the training data
    train_vol_names = glob.glob(os.path.join(data_dir, '*.npz'))
    random.shuffle(train_vol_names)

    # Prepare the vm1 or vm2 model and send to device
    nf_enc = [16, 32, 32, 32]
    if model == "vm1":
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == "vm2":
        nf_dec = [32, 32, 32, 32, 32, 16, 16]
    else:
        raise ValueError("Not yet implemented!")

    model = cvpr2018_net(vol_size, nf_enc, nf_dec)
    model.to(device)

    # Set optimizer and losses
    opt = Adam(model.parameters(), lr=lr)

    sim_loss_fn = losses.ncc_loss if data_loss == "ncc" else losses.mse_loss
    grad_loss_fn = losses.gradient_loss

    # data generator
    train_example_gen = datagenerators.example_gen(train_vol_names, batch_size)

    # set up atlas tensor
    atlas_vol_bs = np.repeat(atlas_vol, batch_size, axis=0)
    input_fixed = torch.from_numpy(atlas_vol_bs).to(device).float()
    input_fixed = input_fixed.permute(0, 4, 1, 2, 3)

    # Training loop.
    for i in range(n_iter):

        # Save model checkpoint
        if i % n_save_iter == 0:
            save_file_name = os.path.join(model_dir, '%d.ckpt' % i)
            torch.save(model.state_dict(), save_file_name)

        # Generate the moving images and convert them to tensors.
        moving_image = next(train_example_gen)[0]
        input_moving = torch.from_numpy(moving_image).to(device).float()
        input_moving = input_moving.permute(0, 4, 1, 2, 3)

        # Run the data through the model to produce warp and flow field
        warp, flow = model(input_moving, input_fixed)

        # Calculate loss
        recon_loss = sim_loss_fn(warp, input_fixed)
        grad_loss = grad_loss_fn(flow)
        loss = recon_loss + reg_param * grad_loss

        print("%d,%f,%f,%f" %
              (i, loss.item(), recon_loss.item(), grad_loss.item()),
              flush=True)

        # Backwards and optimize
        opt.zero_grad()
        loss.backward()
        opt.step()
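A shape check for the permute calls above: the data generators yield channels-last numpy arrays, while the PyTorch model expects channels-first tensors. The volume shape is illustrative:

import numpy as np
import torch

x = np.zeros((1, 160, 192, 224, 1), dtype=np.float32)   # (N, D, H, W, C) from the generator
t = torch.from_numpy(x).permute(0, 4, 1, 2, 3)           # (N, C, D, H, W) for the model
print(t.shape)                                           # torch.Size([1, 1, 160, 192, 224])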
Code example #10
File: train.py Project: yellowgardenia/HDAR
def train(src_dir,
          tgt_dir,
          model_dir,
          model_lr_dir,
          lr,
          nb_epochs,
          reg_param,
          steps_per_epoch,
          batch_size,
          load_model_file=None,
          data_loss='ncc',
          initial_epoch=0):
    """
    model training function
    :param data_dir: folder with npz files for each subject.
    :param atlas_file: atlas filename. So far we support npz file with a 'vol' variable
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param model_dir: the model directory to save to
    :param lr: learning rate
    :param n_iterations: number of training iterations
    :param reg_param: the smoothness/reconstruction tradeoff parameter (lambda in CVPR paper)
    :param steps_per_epoch: frequency with which to save models
    :param batch_size: Optional, default of 1. can be larger, depends on GPU memory and volume size
    :param load_model_file: optional h5 model file to initialize with
    :param data_loss: data_loss: 'mse' or 'ncc
    """

    # prepare data files
    # for the CVPR and MICCAI papers, we have data arranged in train/validate/test folders
    # inside each folder is a /vols/ and a /asegs/ folder with the volumes
    # and segmentations. All of our papers use npz formated data.
    src_vol_names = glob.glob(os.path.join(src_dir, '*.npz'))
    tgt_vol_names = glob.glob(os.path.join(tgt_dir, '*.npz'))
    random.shuffle(src_vol_names)  # shuffle volume list
    random.shuffle(tgt_vol_names)  # shuffle volume list
    assert len(src_vol_names) > 0, "Could not find any training data"

    assert data_loss in [
        'mse', 'ncc'
    ], 'Loss should be one of mse or ncc, found %s' % data_loss
    if data_loss == 'ncc':
        data_loss = losses.NCC().loss
    else:
        # 'mse' case: without this branch the string itself would be called below;
        # a plain TF mean-squared-error stands in here.
        data_loss = lambda y_t, y_p: tf.reduce_mean(tf.square(y_t - y_p))

    # GPU handling
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    # set_session(tf.Session(config=config))

    vol_size = (56, 56, 56)
    # prepare the model
    src_lr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='input_src_lr')
    tgt_lr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='input_tgt_lr')
    srm_lr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='mask_src_lr')
    attn_lr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='attn_lr')

    src_mr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='input_src_mr')
    tgt_mr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='input_tgt_mr')
    srm_mr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='mask_src_mr')
    df_lr2mr = tf.placeholder(tf.float32, [None, *vol_size, 3],
                              name='df_lr2mr')
    attn_mr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='attn_mr')

    src_hr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='input_src_hr')
    tgt_hr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='input_tgt_hr')
    srm_hr = tf.placeholder(tf.float32, [None, *vol_size, 1],
                            name='mask_src_hr')
    df_mr2hr = tf.placeholder(tf.float32, [None, *vol_size, 3],
                              name='df_mr2hr')
    attn_hr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='attn_hr')

    model_lr = networks.net_lr(src_lr, tgt_lr, srm_lr)
    model_mr = networks.net_mr(src_mr, tgt_mr, srm_mr, df_lr2mr)
    model_hr = networks.net_hr(src_hr, tgt_hr, srm_hr, df_mr2hr)

    # the loss functions
    lr_ncc = data_loss(model_lr[0].outputs, tgt_lr)
    #lr_grd = losses.Grad('l2').loss(model_lr[0].outputs, model_lr[2].outputs)
    lr_grd = losses.Anti_Folding('l2').loss(model_lr[0].outputs,
                                            model_lr[2].outputs)

    cost_lr = lr_ncc + reg_param * lr_grd  # + lr_attn

    mr_ncc = data_loss(model_mr[0].outputs, tgt_mr)
    #mr_grd = losses.Grad('l2').loss(model_mr[0].outputs, model_mr[2].outputs)
    mr_grd = losses.Anti_Folding('l2').loss(model_mr[0].outputs,
                                            model_mr[2].outputs)

    cost_mr = mr_ncc + reg_param * mr_grd

    hr_ncc = data_loss(model_hr[0].outputs, tgt_hr)
    #hr_grd = losses.Grad('l2').loss(model_hr[0].outputs, model_hr[2].outputs)
    hr_grd = losses.Anti_Folding('l2').loss(model_hr[0].outputs,
                                            model_hr[2].outputs)

    cost_hr = hr_ncc + reg_param * hr_grd

    # the training operations
    def get_v(name):
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if name in var.name]
        return d_vars

    #attn_vars = tl.layers.get_variables_with_name('cbam_1', True, True)
    attn_vars = get_v('cbam_1')
    for a_v in attn_vars:
        print(a_v)

    train_op_lr = tf.train.AdamOptimizer(lr).minimize(cost_lr)

    train_op_mr = tf.train.AdamOptimizer(lr).minimize(cost_mr)
    train_op_hr = tf.train.AdamOptimizer(lr).minimize(cost_hr)

    # data generator
    src_example_gen = datagenerators.example_gen(src_vol_names,
                                                 batch_size=batch_size)
    tgt_example_gen = datagenerators.example_gen(tgt_vol_names,
                                                 batch_size=batch_size)

    data_gen = datagenerators.gen_with_mask(src_example_gen,
                                            tgt_example_gen,
                                            batch_size=batch_size)

    variables_to_restore = tf.contrib.framework.get_variables_to_restore(
        exclude=['net_hr'])
    saver = tf.train.Saver(variables_to_restore)

    #saver = tf.train.Saver(max_to_keep=3)
    # fit generator
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        # load initial weights
        try:
            if load_model_file is not None:
                model_file = tf.train.latest_checkpoint(load_model_file)
                saver.restore(sess, model_file)
        except Exception:
            print('No checkpoint files in', load_model_file)
        saver.save(sess, model_dir + 'dfnet', global_step=0)

        def resize_df(df, zoom):
            df1 = nd.interpolation.zoom(
                df[0, :, :, :, 0], zoom=zoom, mode='nearest', order=3) * zoom[
                    0]  # Cubic: order=3; Bilinear: order=1; Nearest: order=0
            df2 = nd.interpolation.zoom(
                df[0, :, :, :,
                   1], zoom=zoom, mode='nearest', order=3) * zoom[1]
            df3 = nd.interpolation.zoom(
                df[0, :, :, :,
                   2], zoom=zoom, mode='nearest', order=3) * zoom[2]
            dfs = np.stack((df1, df2, df3), axis=3)
            return dfs[np.newaxis, :, :, :]

        class logPrinter(object):
            def __init__(self):
                self.n_batch = 0
                self.dice = []
                self.cost = []
                self.ncc = []
                self.grd = []

            def addLog(self, dice, cost, ncc, grd):
                self.n_batch += 1
                self.dice.append(dice)
                self.cost.append(cost)
                self.ncc.append(ncc)
                self.grd.append(grd)

            def output(self):
                dice = np.array(self.dice).mean(axis=0).round(3).tolist()
                cost = np.array(self.cost).mean()
                ncc = np.array(self.ncc).mean()
                grd = np.array(self.grd).mean()
                return dice, cost, ncc, grd, self.n_batch

            def clear(self):
                self.n_batch = 0
                self.dice = []
                self.cost = []
                self.ncc = []
                self.grd = []

        lr_log = logPrinter()
        mr_log = logPrinter()
        hr_log = logPrinter()

        # train low resolution
        # load initial weights
        saver = tf.train.Saver(max_to_keep=1)
        #if model_lr_dir is not None:
        #    model_lr_dir = tf.train.latest_checkpoint(model_lr_dir)  #
        #    print(model_lr_dir)
        #    saver.restore(sess, model_lr_dir)

        nb_epochs = 20  # hard-coded for the low-resolution stage
        steps_per_epoch = 30 * 29
        for epoch in range(nb_epochs):
            tbar = trange(steps_per_epoch, unit='batch', ncols=100)
            lr_log.clear()
            for i in tbar:
                image, mask = data_gen.__next__()
                global_X, global_atlas = image
                global_X_mask, global_atlas_mask = mask
                global_diff = global_X[0, :, :, :,
                                       0] - global_atlas[0, :, :, :, 0]

                # low resolution
                global_X_64 = nd.interpolation.zoom(global_X[0, :, :, :, 0],
                                                    zoom=(0.25, 0.25, 0.25),
                                                    mode='nearest')
                global_A_64 = nd.interpolation.zoom(global_atlas[0, :, :, :,
                                                                 0],
                                                    zoom=(0.25, 0.25, 0.25),
                                                    mode='nearest')
                global_XM_64 = nd.interpolation.zoom(global_X_mask[0, :, :, :,
                                                                   0],
                                                     zoom=(0.25, 0.25, 0.25),
                                                     mode='nearest',
                                                     order=0)
                global_AM_64 = nd.interpolation.zoom(
                    global_atlas_mask[0, :, :, :, 0],
                    zoom=(0.25, 0.25, 0.25),
                    mode='nearest',
                    order=0)
                global_diff_16 = nd.interpolation.zoom(global_diff,
                                                       zoom=(0.25, 0.25, 0.25),
                                                       mode='nearest')

                global_X_64 = global_X_64[np.newaxis, :, :, :, np.newaxis]
                global_A_64 = global_A_64[np.newaxis, :, :, :, np.newaxis]
                global_XM_64 = global_XM_64[np.newaxis, :, :, :, np.newaxis]
                global_AM_64 = global_AM_64[np.newaxis, :, :, :, np.newaxis]
                global_diff_16 = global_diff_16[np.newaxis, :, :, :,
                                                np.newaxis]

                feed_dict = {
                    src_lr: global_X_64,
                    tgt_lr: global_A_64,
                    srm_lr: global_XM_64,
                    attn_lr: global_diff_16
                }
                err_lr, _ = sess.run([cost_lr, train_op_lr],
                                     feed_dict=feed_dict)
                df_lr, warp_seg, elr_ncc, elr_grad, lr_attn_map, lr_attn_feature = sess.run(
                    [
                        model_lr[2].outputs, model_lr[1].outputs, lr_ncc,
                        lr_grd, model_lr[3], model_lr[4]
                    ],
                    feed_dict=feed_dict)
                # print(df_lr.shape)
                lr_dice, _ = dice(warp_seg[0, :, :, :, 0],
                                  global_AM_64[0, :, :, :, 0],
                                  labels=[0, 10, 150, 250],
                                  nargout=2)
                lr_log.addLog(lr_dice, err_lr, elr_ncc, elr_grad)
                lr_out = lr_log.output()

                tbar.set_description('Epoch %d/%d ### step %i' %
                                     (epoch + 1, nb_epochs, i))
                tbar.set_postfix(lr_dice=lr_out[0],
                                 lr_cost=lr_out[1],
                                 lr_ncc=lr_out[2],
                                 lr_grd=lr_out[3])

            saver.save(sess, model_lr_dir + 'dfnet', global_step=0)
        # train middle resolution
        nb_epochs = 1  # hard-coded for the middle-resolution stage
        steps_per_epoch = 30 * 29
        for epoch in range(nb_epochs):
            lr_log.clear()
            for lr_step in range(steps_per_epoch):
                image, mask = data_gen.__next__()
                global_X, global_atlas = image
                global_X_mask, global_atlas_mask = mask
                global_diff = global_X[0, :, :, :,
                                       0] - global_atlas[0, :, :, :, 0]

                # low resolution
                global_X_64 = nd.interpolation.zoom(global_X[0, :, :, :, 0],
                                                    zoom=(0.25, 0.25, 0.25),
                                                    mode='nearest')
                global_A_64 = nd.interpolation.zoom(global_atlas[0, :, :, :,
                                                                 0],
                                                    zoom=(0.25, 0.25, 0.25),
                                                    mode='nearest')
                global_XM_64 = nd.interpolation.zoom(global_X_mask[0, :, :, :,
                                                                   0],
                                                     zoom=(0.25, 0.25, 0.25),
                                                     mode='nearest',
                                                     order=0)
                global_AM_64 = nd.interpolation.zoom(
                    global_atlas_mask[0, :, :, :, 0],
                    zoom=(0.25, 0.25, 0.25),
                    mode='nearest',
                    order=0)
                global_diff_16 = nd.interpolation.zoom(global_diff,
                                                       zoom=(0.25, 0.25, 0.25),
                                                       mode='nearest')

                global_X_64 = global_X_64[np.newaxis, :, :, :, np.newaxis]
                global_A_64 = global_A_64[np.newaxis, :, :, :, np.newaxis]
                global_XM_64 = global_XM_64[np.newaxis, :, :, :, np.newaxis]
                global_AM_64 = global_AM_64[np.newaxis, :, :, :, np.newaxis]
                global_diff_16 = global_diff_16[np.newaxis, :, :, :,
                                                np.newaxis]

                feed_dict = {
                    src_lr: global_X_64,
                    tgt_lr: global_A_64,
                    srm_lr: global_XM_64,
                    attn_lr: global_diff_16
                }
                err_lr, _ = sess.run([cost_lr, train_op_lr],
                                     feed_dict=feed_dict)
                df_lr, warp_seg, elr_ncc, elr_grad, lr_attn_map, lr_attn_feature = sess.run(
                    [
                        model_lr[2].outputs, model_lr[1].outputs, lr_ncc,
                        lr_grd, model_lr[3], model_lr[4]
                    ],
                    feed_dict=feed_dict)

                lr_dice, _ = dice(warp_seg[0, :, :, :, 0],
                                  global_AM_64[0, :, :, :, 0],
                                  labels=[0, 10, 150, 250],
                                  nargout=2)
                lr_log.addLog(lr_dice, err_lr, elr_ncc, elr_grad)
                lr_out = lr_log.output()

                print('\nEpoch %d/%d ### step %i' %
                      (epoch + 1, nb_epochs, lr_out[-1]))
                print(
                    '[lr] lr_dice={}, lr_cost={:.3f}, lr_ncc={:.3f}, lr_grd={:.3f}'
                    .format(lr_out[0], lr_out[1], lr_out[2], lr_out[3]))

                # middle part
                df_lr_res2mr = resize_df(df_lr, zoom=(2, 2, 2))

                select_points_lr = patch_selection_attn(lr_attn_map,
                                                        zoom_scales=[8, 8, 8],
                                                        kernel=7,
                                                        mi=10,
                                                        ma=18)
                print(select_points_lr)
                mr_log.clear()

                for sp in select_points_lr:
                    mov_img_112 = global_X[0, sp[0] - 56:sp[0] + 56,
                                           sp[1] - 56:sp[1] + 56,
                                           sp[2] - 56:sp[2] + 56, 0]
                    fix_img_112 = global_atlas[0, sp[0] - 56:sp[0] + 56,
                                               sp[1] - 56:sp[1] + 56,
                                               sp[2] - 56:sp[2] + 56, 0]
                    mov_seg_112 = global_X_mask[0, sp[0] - 56:sp[0] + 56,
                                                sp[1] - 56:sp[1] + 56,
                                                sp[2] - 56:sp[2] + 56, 0]
                    fix_seg_112 = global_atlas_mask[0, sp[0] - 56:sp[0] + 56,
                                                    sp[1] - 56:sp[1] + 56,
                                                    sp[2] - 56:sp[2] + 56, 0]
                    dif_img_112 = global_diff[sp[0] - 56:sp[0] + 56,
                                              sp[1] - 56:sp[1] + 56,
                                              sp[2] - 56:sp[2] + 56]

                    #print(mov_img_112.shape)
                    if fix_img_112.shape != (112, 112, 112):
                        print(mov_img_112.shape)
                        continue
                    fix_112_56 = nd.interpolation.zoom(fix_img_112,
                                                       zoom=(0.5, 0.5, 0.5),
                                                       mode='nearest')
                    mov_112_56 = nd.interpolation.zoom(mov_img_112,
                                                       zoom=(0.5, 0.5, 0.5),
                                                       mode='nearest')
                    fix_112_56m = nd.interpolation.zoom(fix_seg_112,
                                                        zoom=(0.5, 0.5, 0.5),
                                                        mode='nearest',
                                                        order=0)
                    mov_112_56m = nd.interpolation.zoom(mov_seg_112,
                                                        zoom=(0.5, 0.5, 0.5),
                                                        mode='nearest',
                                                        order=0)
                    dif_112_56 = nd.interpolation.zoom(dif_img_112,
                                                       zoom=(0.5, 0.5, 0.5),
                                                       mode='nearest')

                    mid_fix_img = fix_112_56[np.newaxis, :, :, :, np.newaxis]
                    mid_mov_img = mov_112_56[np.newaxis, :, :, :, np.newaxis]
                    mid_fix_seg = fix_112_56m[np.newaxis, :, :, :, np.newaxis]
                    mid_mov_seg = mov_112_56m[np.newaxis, :, :, :, np.newaxis]
                    mid_dif_img = dif_112_56[np.newaxis, :, :, :, np.newaxis]
                    df_mr_feed = df_lr_res2mr[:,
                                              sp[0] // 2 - 28:sp[0] // 2 + 28,
                                              sp[1] // 2 - 28:sp[1] // 2 + 28,
                                              sp[2] // 2 - 28:sp[2] // 2 +
                                              28, :]

                    feed_dict = {
                        src_mr: mid_mov_img,
                        tgt_mr: mid_fix_img,
                        srm_mr: mid_mov_seg,
                        df_lr2mr: df_mr_feed,
                        attn_mr: mid_dif_img
                    }
                    err_mr, _ = sess.run([cost_mr, train_op_mr],
                                         feed_dict=feed_dict)
                    df_mr, warp_seg, emr_ncc, emr_grad = sess.run(
                        [
                            model_mr[2].outputs, model_mr[1].outputs, mr_ncc,
                            mr_grd
                        ],
                        feed_dict=feed_dict)

                    mr_dice, _ = dice(warp_seg[0, :, :, :, 0],
                                      mid_fix_seg[0, :, :, :, 0],
                                      labels=[0, 10, 150, 250],
                                      nargout=2)
                    mr_log.addLog(mr_dice, err_mr, emr_ncc, emr_grad)
                    mr_out = mr_log.output()

                    # print('  Epoch %d/%d ### step %i' % (epoch+1, nb_epochs, mr_out[-1]))
                    print(
                        '  [mr] {}/{} mr_dice={}, mr_cost={:.3f}, mr_ncc={:.3f}, mr_grd={:.3f}'
                        .format(mr_out[-1], len(select_points_lr), mr_out[0],
                                mr_out[1], mr_out[2], mr_out[3]))

            saver.save(sess, model_dir + 'dfnet', global_step=0)

        # train high resolution
        nb_epochs = 1
        steps_per_epoch = 300
        for epoch in range(nb_epochs):
            lr_log.clear()
            for lr_step in range(steps_per_epoch):
                image, mask = data_gen.__next__()
                global_X, global_atlas = image
                global_X_mask, global_atlas_mask = mask
                global_diff = global_X[0, :, :, :,
                                       0] - global_atlas[0, :, :, :, 0]

                # low resolution
                global_X_64 = nd.interpolation.zoom(global_X[0, :, :, :, 0],
                                                    zoom=(0.25, 0.25, 0.25),
                                                    mode='nearest')
                global_A_64 = nd.interpolation.zoom(global_atlas[0, :, :, :,
                                                                 0],
                                                    zoom=(0.25, 0.25, 0.25),
                                                    mode='nearest')
                global_XM_64 = nd.interpolation.zoom(global_X_mask[0, :, :, :,
                                                                   0],
                                                     zoom=(0.25, 0.25, 0.25),
                                                     mode='nearest',
                                                     order=0)
                global_AM_64 = nd.interpolation.zoom(
                    global_atlas_mask[0, :, :, :, 0],
                    zoom=(0.25, 0.25, 0.25),
                    mode='nearest',
                    order=0)
                global_diff_16 = nd.interpolation.zoom(global_diff,
                                                       zoom=(0.25, 0.25, 0.25),
                                                       mode='nearest')

                global_X_64 = global_X_64[np.newaxis, :, :, :, np.newaxis]
                global_A_64 = global_A_64[np.newaxis, :, :, :, np.newaxis]
                global_XM_64 = global_XM_64[np.newaxis, :, :, :, np.newaxis]
                global_AM_64 = global_AM_64[np.newaxis, :, :, :, np.newaxis]
                global_diff_16 = global_diff_16[np.newaxis, :, :, :,
                                                np.newaxis]

                feed_dict = {
                    src_lr: global_X_64,
                    tgt_lr: global_A_64,
                    srm_lr: global_XM_64,
                    attn_lr: global_diff_16
                }
                err_lr, _ = sess.run([cost_lr, train_op_lr],
                                     feed_dict=feed_dict)
                df_lr, warp_seg, elr_ncc, elr_grad, lr_attn_map, lr_attn_feature = sess.run(
                    [
                        model_lr[2].outputs, model_lr[1].outputs, lr_ncc,
                        lr_grd, model_lr[3], model_lr[4]
                    ],
                    feed_dict=feed_dict)

                lr_dice, _ = dice(warp_seg[0, :, :, :, 0],
                                  global_AM_64[0, :, :, :, 0],
                                  labels=[0, 10, 150, 250],
                                  nargout=2)
                lr_log.addLog(lr_dice, err_lr, elr_ncc, elr_grad)
                lr_out = lr_log.output()

                print('\nEpoch %d/%d ### step %i' %
                      (epoch + 1, nb_epochs, lr_out[-1]))
                print(
                    '[lr] lr_dice={}, lr_cost={:.3f}, lr_ncc={:.3f}, lr_grd={:.3f}'
                    .format(lr_out[0], lr_out[1], lr_out[2], lr_out[3]))

                # middle part
                df_lr_res2mr = resize_df(df_lr, zoom=(2, 2, 2))

                select_points_lr = patch_selection_attn(lr_attn_map,
                                                        zoom_scales=[8, 8, 8],
                                                        kernel=7,
                                                        mi=10,
                                                        ma=18)
                print(select_points_lr)
                mr_log.clear()

                for sp in select_points_lr:
                    mov_img_112 = global_X[0, sp[0] - 56:sp[0] + 56,
                                           sp[1] - 56:sp[1] + 56,
                                           sp[2] - 56:sp[2] + 56, 0]
                    fix_img_112 = global_atlas[0, sp[0] - 56:sp[0] + 56,
                                               sp[1] - 56:sp[1] + 56,
                                               sp[2] - 56:sp[2] + 56, 0]
                    mov_seg_112 = global_X_mask[0, sp[0] - 56:sp[0] + 56,
                                                sp[1] - 56:sp[1] + 56,
                                                sp[2] - 56:sp[2] + 56, 0]
                    fix_seg_112 = global_atlas_mask[0, sp[0] - 56:sp[0] + 56,
                                                    sp[1] - 56:sp[1] + 56,
                                                    sp[2] - 56:sp[2] + 56, 0]
                    dif_img_112 = global_diff[sp[0] - 56:sp[0] + 56,
                                              sp[1] - 56:sp[1] + 56,
                                              sp[2] - 56:sp[2] + 56]

                    #print(mov_img_112.shape)
                    if fix_img_112.shape != (112, 112, 112):
                        print(mov_img_112.shape)
                        continue
                    fix_112_56 = nd.interpolation.zoom(fix_img_112,
                                                       zoom=(0.5, 0.5, 0.5),
                                                       mode='nearest')
                    mov_112_56 = nd.interpolation.zoom(mov_img_112,
                                                       zoom=(0.5, 0.5, 0.5),
                                                       mode='nearest')
                    fix_112_56m = nd.interpolation.zoom(fix_seg_112,
                                                        zoom=(0.5, 0.5, 0.5),
                                                        mode='nearest',
                                                        order=0)
                    mov_112_56m = nd.interpolation.zoom(mov_seg_112,
                                                        zoom=(0.5, 0.5, 0.5),
                                                        mode='nearest',
                                                        order=0)
                    dif_112_56 = nd.interpolation.zoom(dif_img_112,
                                                       zoom=(0.5, 0.5, 0.5),
                                                       mode='nearest')

                    mid_fix_img = fix_112_56[np.newaxis, :, :, :, np.newaxis]
                    mid_mov_img = mov_112_56[np.newaxis, :, :, :, np.newaxis]
                    mid_fix_seg = fix_112_56m[np.newaxis, :, :, :, np.newaxis]
                    mid_mov_seg = mov_112_56m[np.newaxis, :, :, :, np.newaxis]
                    mid_dif_img = dif_112_56[np.newaxis, :, :, :, np.newaxis]
                    df_mr_feed = df_lr_res2mr[:,
                                              sp[0] // 2 - 28:sp[0] // 2 + 28,
                                              sp[1] // 2 - 28:sp[1] // 2 + 28,
                                              sp[2] // 2 - 28:sp[2] // 2 +
                                              28, :]

                    feed_dict = {
                        src_mr: mid_mov_img,
                        tgt_mr: mid_fix_img,
                        srm_mr: mid_mov_seg,
                        df_lr2mr: df_mr_feed,
                        attn_mr: mid_dif_img
                    }
                    err_mr, _ = sess.run([cost_mr, train_op_mr],
                                         feed_dict=feed_dict)
                    df_mr, warp_seg, emr_ncc, emr_grad, mr_attn_map, mr_attn_feature = sess.run(
                        [
                            model_mr[2].outputs, model_mr[1].outputs, mr_ncc,
                            mr_grd, model_mr[3], model_mr[4]
                        ],
                        feed_dict=feed_dict)

                    mr_dice, _ = dice(warp_seg[0, :, :, :, 0],
                                      mid_fix_seg[0, :, :, :, 0],
                                      labels=[0, 10, 150, 250],
                                      nargout=2)
                    mr_log.addLog(mr_dice, err_mr, emr_ncc, emr_grad)
                    mr_out = mr_log.output()

                    # print('  Epoch %d/%d ### step %i' % (epoch+1, nb_epochs, mr_out[-1]))
                    print(
                        '  [mr] {}/{} mr_dice={}, mr_cost={:.3f}, mr_ncc={:.3f}, mr_grd={:.3f}'
                        .format(mr_out[-1], len(select_points_lr), mr_out[0],
                                mr_out[1], mr_out[2], mr_out[3]))

                    # high part
                    df_mr_res2hr = resize_df(df_mr, zoom=(2, 2, 2))
                    hr_log.clear()
                    select_points_mr = patch_selection_attn(
                        mr_attn_map,
                        zoom_scales=[4, 4, 4],
                        kernel=7,
                        mi=8,
                        ma=20)
                    print(30 * '-')
                    print('High Part')
                    print(select_points_mr)
                    for spm in select_points_mr:
                        fix_img_56 = fix_img_112[spm[0] - 28:spm[0] + 28,
                                                 spm[1] - 28:spm[1] + 28,
                                                 spm[2] - 28:spm[2] + 28]
                        mov_img_56 = mov_img_112[spm[0] - 28:spm[0] + 28,
                                                 spm[1] - 28:spm[1] + 28,
                                                 spm[2] - 28:spm[2] + 28]
                        fix_seg_56 = fix_seg_112[spm[0] - 28:spm[0] + 28,
                                                 spm[1] - 28:spm[1] + 28,
                                                 spm[2] - 28:spm[2] + 28]
                        mov_seg_56 = mov_seg_112[spm[0] - 28:spm[0] + 28,
                                                 spm[1] - 28:spm[1] + 28,
                                                 spm[2] - 28:spm[2] + 28]
                        dif_img_56 = dif_img_112[spm[0] - 28:spm[0] + 28,
                                                 spm[1] - 28:spm[1] + 28,
                                                 spm[2] - 28:spm[2] + 28]
                        if fix_img_56.shape != (56, 56, 56):
                            continue

                        hig_fix_img = fix_img_56[np.newaxis, :, :, :,
                                                 np.newaxis]
                        hig_mov_img = mov_img_56[np.newaxis, :, :, :,
                                                 np.newaxis]
                        hig_fix_seg = fix_seg_56[np.newaxis, :, :, :,
                                                 np.newaxis]
                        hig_mov_seg = mov_seg_56[np.newaxis, :, :, :,
                                                 np.newaxis]
                        hig_dif_img = dif_img_56[np.newaxis, :, :, :,
                                                 np.newaxis]

                        df_hr_feed = df_mr_res2hr[:, spm[0] - 28:spm[0] + 28,
                                                  spm[1] - 28:spm[1] + 28,
                                                  spm[2] - 28:spm[2] + 28, :]

                        feed_dict = {
                            src_hr: hig_mov_img,
                            tgt_hr: hig_fix_img,
                            srm_hr: hig_mov_seg,
                            df_mr2hr: df_hr_feed,
                            attn_hr: hig_dif_img
                        }
                        err_hr, _ = sess.run([cost_hr, train_op_hr],
                                             feed_dict=feed_dict)
                        df_hr, warp_seg, ehr_ncc, ehr_grad = sess.run(
                            [
                                model_hr[2].outputs, model_hr[1].outputs,
                                hr_ncc, hr_grd
                            ],
                            feed_dict=feed_dict)

                        hr_dice, _ = dice(warp_seg[0, :, :, :, 0],
                                          hig_fix_seg[0, :, :, :, 0],
                                          labels=[0, 10, 150, 250],
                                          nargout=2)
                        hr_log.addLog(hr_dice, err_hr, ehr_ncc, ehr_grad)
                        hr_out = hr_log.output()

                        # print('  Epoch %d/%d ### step %i' % (epoch+1, nb_epochs, mr_out[-1]))
                        print(
                            '    [hr] {}/{} hr_dice={}, hr_cost={:.3f}, hr_ncc={:.3f}, hr_grd={:.3f}'
                            .format(hr_out[-1], len(select_points_mr),
                                    hr_out[0], hr_out[1], hr_out[2],
                                    hr_out[3]))

                saver.save(sess, model_dir + 'dfnet', global_step=lr_step)
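
The loop above upsamples the mid-resolution displacement field with resize_df(df_mr, zoom=(2, 2, 2)) before cropping patches for the high-resolution stage. The subtlety when resizing a displacement field is that the vector values must be scaled together with the grid. resize_df itself is not shown in this excerpt, so the following is only a plausible sketch of such a helper, assuming SciPy is available and a batched field of shape (1, D, H, W, 3):

import numpy as np
from scipy.ndimage import zoom as nd_zoom

def resize_df_sketch(df, zoom=(2, 2, 2)):
    # interpolate each vector component spatially, then rescale its
    # magnitude by the zoom factor of the matching axis so displacements
    # stay expressed in voxels of the new grid
    components = []
    for c in range(df.shape[-1]):
        comp = nd_zoom(df[0, ..., c], zoom, order=1)
        components.append(comp * zoom[c])
    return np.stack(components, axis=-1)[np.newaxis, ...]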
Code Example #11
def train(gpu, data_dir, size, atlas_dir, lr, n_iter, data_loss, model,
          reg_param, batch_size, n_save_iter, model_dir, nr_val_data):
    """
    model training function
    :param gpu: integer specifying the gpu to use
    :param data_dir: folder with npz files for each subject.
    :param size: int desired size of the volumes: [size,size,size]
    :param atlas_dir: path to the atlas folder
    :param lr: learning rate
    :param n_iter: number of training iterations
    :param data_loss: 'mse' or 'ncc'
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param reg_param: the smoothness/reconstruction tradeoff parameter (lambda in CVPR paper)
    :param batch_size: Optional, default of 1. can be larger, depends on GPU memory and volume size
    :param n_save_iter: Optional, default of 500. Determines how many iterations pass between model checkpoints.
    :param model_dir: the model directory to save to
    :param nr_val_data: number of validation examples that should be separated from the training data
    """

    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    device = "cuda"
    vol_size = np.array([size, size, size])
    # Get all the names of the training data
    vol_names = glob.glob(os.path.join(data_dir, '*.nii'))
    #random.shuffle(vol_names)
    #test_vol_names =  vol_names[-nr_val_data:]
    test_vol_names = vol_names[:nr_val_data]
    #test_vol_names = [i for i in test_vol_names if "L2-L4" in i]
    print('these volumes are separated from the data and serve as validation data:')
    print(test_vol_names)

    #train_vol_names = vol_names[:-nr_val_data]
    train_vol_names = vol_names[nr_val_data:]
    #train_vol_names = [i for i in train_vol_names if "L2-L4" in i]

    random.shuffle(train_vol_names)
    writer = SummaryWriter(get_outputs_path())

    # Prepare the vm1 or vm2 model and send to device
    nf_enc = [16, 32, 32, 32]
    if model == "vm1":
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == "vm2":
        nf_dec = [32, 32, 32, 32, 32, 16, 16]
    else:
        raise ValueError("model should be 'vm1' or 'vm2'")

    model = cvpr2018_net(vol_size, nf_enc, nf_dec)
    model.to(device)

    # Set optimizer and losses
    opt = Adam(model.parameters(), lr=lr)

    sim_loss_fn = losses.ncc_loss if data_loss == "ncc" else losses.mse_loss
    grad_loss_fn = losses.gradient_loss

    # data generator
    train_example_gen = datagenerators.example_gen(train_vol_names, atlas_dir,
                                                   size, batch_size)

    # Training loop.
    for i in range(n_iter):

        # Save model checkpoint and plot validation score
        if i % n_save_iter == 0:
            save_file_name = os.path.join(model_dir, '%d.ckpt' % i)
            torch.save(model.state_dict(), save_file_name)
            # load validation data
            val_example_gen = datagenerators.example_gen(
                test_vol_names, atlas_dir, size, 4)
            val_data = next(val_example_gen)
            val_fixed = torch.from_numpy(val_data[1]).to(device).float()
            val_fixed = val_fixed.permute(0, 4, 1, 2, 3)
            val_moving = torch.from_numpy(val_data[0]).to(device).float()
            val_moving = val_moving.permute(0, 4, 1, 2, 3)

            # run the validation data through the model;
            # no_grad avoids building an autograd graph we never backprop through
            with torch.no_grad():
                val_warp, val_flow = model(val_moving, val_fixed)

                # calculate validation score
                val_recon_loss = sim_loss_fn(val_warp, val_fixed)
                val_grad_loss = grad_loss_fn(val_flow)
                val_loss = val_recon_loss + reg_param * val_grad_loss

            #tensorboard
            writer.add_scalar('Loss/Test', val_loss, i)

            #prints
            print('validation')
            print("%d,%f,%f,%f" % (i, val_loss.item(), val_recon_loss.item(),
                                   val_grad_loss.item()),
                  flush=True)

        # Generate the moving images and convert them to tensors.

        data_for_network = next(train_example_gen)
        input_fixed = torch.from_numpy(data_for_network[1]).to(device).float()
        input_fixed = input_fixed.permute(0, 4, 1, 2, 3)
        input_moving = torch.from_numpy(data_for_network[0]).to(device).float()
        input_moving = input_moving.permute(0, 4, 1, 2, 3)

        # Run the data through the model to produce warp and flow field
        warp, flow = model(input_moving, input_fixed)
        print("warp_and_flow_field")
        print(warp.size())
        print(flow.size())

        # Calculate loss
        recon_loss = sim_loss_fn(warp, input_fixed)
        grad_loss = grad_loss_fn(flow)
        loss = recon_loss + reg_param * grad_loss

        #tensorboard
        writer.add_scalar('Loss/Train', loss, i)
        print("%d,%f,%f,%f" %
              (i, loss.item(), recon_loss.item(), grad_loss.item()),
              flush=True)

        # Backward pass and optimizer step
        opt.zero_grad()
        loss.backward()
        opt.step()
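
Example #11 regularizes the predicted flow with losses.gradient_loss, which is not reproduced in this excerpt. The standard form of this penalty is the mean (squared, for 'l2') forward difference of the flow along each spatial axis; a minimal PyTorch sketch under that assumption, for flow tensors shaped (N, 3, D, H, W) as in the permuted tensors above:

import torch

def gradient_loss_sketch(flow, penalty='l2'):
    # forward differences along each spatial axis of a (N, 3, D, H, W) flow
    dz = torch.abs(flow[:, :, 1:, :, :] - flow[:, :, :-1, :, :])
    dy = torch.abs(flow[:, :, :, 1:, :] - flow[:, :, :, :-1, :])
    dx = torch.abs(flow[:, :, :, :, 1:] - flow[:, :, :, :, :-1])
    if penalty == 'l2':
        dz, dy, dx = dz * dz, dy * dy, dx * dx
    return (dz.mean() + dy.mean() + dx.mean()) / 3.0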
Code Example #12
def train(data_dir,
          atlas_file,
          model,
          model_name,
          gpu_id,
          lr,
          nb_epochs,
          reg_param,
          steps_per_epoch,
          batch_size,
          load_model_file,
          data_loss,
          initial_epoch=0):
    """
    model training function
    :param data_dir: folder with npz files for each subject.
    :param atlas_file: atlas filename. So far we support npz file with a 'vol' variable
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param model_name: name of the model directory to save to (under ../models/)
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param nb_epochs: number of training epochs
    :param reg_param: the smoothness/reconstruction tradeoff parameter (lambda in CVPR paper)
    :param steps_per_epoch: number of training batches per epoch (a checkpoint is saved each epoch)
    :param batch_size: Optional, default of 1. can be larger, depends on GPU memory and volume size
    :param load_model_file: optional h5 model file to initialize with
    :param data_loss: 'mse' or 'ncc'
    :param initial_epoch: epoch to resume training from
    """

    # load atlas from provided files. The atlas we used is 160x192x224.
    # atlas_vol = np.load(atlas_file)['vol'][np.newaxis, ..., np.newaxis]
    atlas_vol = nib.load(atlas_file).get_data()[np.newaxis, ..., np.newaxis]
    vol_size = atlas_vol.shape[1:-1]

    # prepare data files
    # for the CVPR and MICCAI papers, we have data arranged in train/validate/test folders
    # inside each folder is a /vols/ and a /asegs/ folder with the volumes
    # and segmentations. All of our papers use npz formated data.
    train_vol_names = glob.glob(os.path.join(data_dir, '*.npz'))
    random.shuffle(train_vol_names)  # shuffle volume list
    assert len(train_vol_names) > 0, "Could not find any training data"

    # UNET filters for voxelmorph-1 and voxelmorph-2,
    # these are architectures presented in CVPR 2018
    nf_enc = [16, 32, 32, 32]
    if model == 'vm1':
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == 'vm2':
        nf_dec = [32, 32, 32, 32, 32, 16, 16]
    else:  # 'vm2double':
        nf_enc = [f * 2 for f in nf_enc]
        nf_dec = [f * 2 for f in [32, 32, 32, 32, 32, 16, 16]]

    assert data_loss in [
        'mse', 'cc', 'ncc'
    ], 'Loss should be one of mse, cc, or ncc; found %s' % data_loss
    if data_loss in ['ncc', 'cc']:
        data_loss = losses.NCC().loss

    model_dir = "../models/" + model_name
    # prepare model folder
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # GPU handling
    gpu = '/gpu:%d' % gpu_id
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # prepare the model
    with tf.device(gpu):
        # prepare the model
        # in the CVPR layout, the model takes in [image_1, image_2] and outputs [warped_image_1, flow]
        # in the experiments, we use image_2 as atlas
        model = networks.cvpr2018_net(vol_size, nf_enc, nf_dec)

        # load initial weights
        if load_model_file is not None and load_model_file != '':
            print('loading', load_model_file)
            model.load_weights(load_model_file)

        # save first iteration
        model.save(os.path.join(model_dir, '%02d.h5' % initial_epoch))

    # data generator
    # nb_gpus = len(gpu_id.split(','))
    # assert np.mod(batch_size, nb_gpus) == 0, \
    #     'batch_size should be a multiple of the nr. of gpus. ' + \
    #     'Got batch_size %d, %d gpus' % (batch_size, nb_gpus)
    nb_gpus = 1

    train_example_gen = datagenerators.example_gen(train_vol_names,
                                                   batch_size=batch_size)
    atlas_vol_bs = np.repeat(atlas_vol, batch_size, axis=0)
    cvpr2018_gen = datagenerators.cvpr2018_gen(train_example_gen,
                                               atlas_vol_bs,
                                               batch_size=batch_size)

    # prepare callbacks
    save_file_name = os.path.join(model_dir, '{epoch:02d}.h5')

    # fit generator
    with tf.device(gpu):

        # multi-gpu support
        if nb_gpus > 1:
            save_callback = nrn_gen.ModelCheckpointParallel(save_file_name)
            mg_model = multi_gpu_model(model, gpus=nb_gpus)

        # single-gpu
        else:
            save_callback = ModelCheckpoint(save_file_name)
            mg_model = model

        # compile
        mg_model.compile(optimizer=Adam(lr=lr),
                         loss=[data_loss, losses.Grad('l2').loss],
                         loss_weights=[1.0, reg_param])

        # fit
        mg_model.fit_generator(cvpr2018_gen,
                               initial_epoch=initial_epoch,
                               epochs=nb_epochs,
                               callbacks=[save_callback],
                               steps_per_epoch=steps_per_epoch,
                               verbose=1)
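
The generator chain in example #12 is worth spelling out: cvpr2018_gen pairs each moving batch from example_gen with the batch-repeated atlas and emits the dummy targets the two-output model expects. A minimal sketch under the generator contract used throughout these examples (the real datagenerators.cvpr2018_gen may differ in details):

import numpy as np

def cvpr2018_gen_sketch(example_gen, atlas_vol_bs, batch_size=1):
    # inputs:  [moving, atlas]; targets: [atlas, zero_flow]
    # the flow target is all zeros because the Grad loss only reads the
    # model's predicted flow, never its "ground truth"
    vol_shape = atlas_vol_bs.shape[1:-1]
    zeros = np.zeros((batch_size, *vol_shape, len(vol_shape)))
    while True:
        X = next(example_gen)[0]
        yield ([X, atlas_vol_bs], [atlas_vol_bs, zeros])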
Code Example #13
def train(model_dir,
          gpu_id,
          lr,
          n_iterations,
          alpha,
          image_sigma,
          model_save_iter,
          batch_size=1):
    """
    model training function
    :param model_dir: model folder to save to
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param n_iterations: number of training iterations
    :param alpha: the scalar weight on the smoothing (Laplacian) term in the MICCAI paper
    :param image_sigma: the image noise parameter sigma in the MICCAI paper
    :param model_save_iter: frequency with which to save models
    :param batch_size: Optional, default of 1. can be larger, depends on GPU memory and volume size
    """

    # prepare model folder
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)
    print(model_dir)

    # gpu handling
    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # Diffeomorphic network architecture used in MICCAI 2018 paper
    nf_enc = [16, 32, 32, 32]
    nf_dec = [32, 32, 32, 32, 16, 3]

    # prepare the model
    # in the CVPR layout, the model takes in [image_1, image_2] and outputs [warped_image_1, velocity_stats]
    # in the experiments, we use image_2 as atlas
    with tf.device(gpu):
        # miccai 2018 used xy indexing.
        model = networks.miccai2018_net(vol_size,
                                        nf_enc,
                                        nf_dec,
                                        use_miccai_int=True,
                                        indexing='xy')

        # compile
        model_losses = [losses.kl_l2loss(image_sigma), losses.kl_loss(alpha)]
        model.compile(optimizer=Adam(lr=lr), loss=model_losses)

        # save first iteration
        model.save(os.path.join(model_dir, str(0) + '.h5'))

    train_example_gen = datagenerators.example_gen(train_vol_names)
    zeros = np.zeros((1, *vol_size, 3))

    # train. Note: we use train_on_batch and design our own print function as this has enabled
    # faster development and debugging, but one could also use fit_generator and Keras callbacks.
    for step in range(1, n_iterations):

        # get_data
        X = next(train_example_gen)[0]

        # train
        with tf.device(gpu):
            train_loss = model.train_on_batch([X, atlas_vol],
                                              [atlas_vol, zeros])

        if not isinstance(train_loss, list):
            train_loss = [train_loss]

        # print
        print_loss(step, 0, train_loss)

        # save model
        with tf.device(gpu):
            if (step % model_save_iter == 0) or step < 10:
                model.save(os.path.join(model_dir, str(step) + '.h5'))
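
For reference, a hypothetical invocation of the MICCAI training function above. Every value below is an illustrative placeholder, not a setting from the paper, and the function also expects module-level train_vol_names, atlas_vol, and vol_size to be defined, as in the snippet:

# hypothetical call; paths and hyperparameters are placeholders
train(model_dir='../models/miccai_run',
      gpu_id=0,
      lr=1e-4,
      n_iterations=150001,
      alpha=0.01,          # weight on the KL smoothness term
      image_sigma=0.02,    # image noise estimate in the MICCAI loss
      model_save_iter=5000)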
Code Example #14
def train(
        data_dir,
        val_data_dir,
        atlas_file,
        val_atlas_file,
        model,
        model_dir,
        gpu_id,
        lr,
        nb_epochs,
        reg_param,
        gama_param,
        steps_per_epoch,
        batch_size,
        load_model_file,
        data_loss,
        seg_dir=None,  # one file
        val_seg_dir=None,
        Sf_file=None,  # one file
        val_Sf_file=None,
        auxi_label=None,
        initial_epoch=0):
    """
    model training function
    :param data_dir: folder with npz files for each subject.
    :param atlas_file: atlas filename. So far we support npz file with a 'vol' variable
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param model_dir: the model directory to save to
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param nb_epochs: number of training epochs
    :param reg_param: the smoothness/reconstruction tradeoff parameter (lambda in CVPR paper)
    :param steps_per_epoch: number of training batches per epoch (a checkpoint is saved each epoch)
    :param batch_size: Optional, default of 1. can be larger, depends on GPU memory and volume size
    :param load_model_file: optional h5 model file to initialize with
    :param data_loss: 'mse' or 'ncc'
    :param auxi_label: whether to use auxiliary information during training
    """

    # load atlas from provided files. The atlas we used is 160x192x224.
    # atlas_file = 'D:/voxel/data/t064.tif'
    atlas = Image.open(atlas_file)  # is a TiffImageFile _size is (628, 690)
    atlas_vol = np.array(atlas)[
        np.newaxis, ..., np.newaxis]  # is a ndarray, shape is (1, 690, 628, 1)
    # new = Image.fromarray(X) new.size is (628, 690)
    vol_size = atlas_vol.shape[1:-1]  # (690, 628)
    print(vol_size)

    val_atlas = Image.open(
        val_atlas_file)  # is a TiffImageFile _size is (628, 690)
    val_atlas_vol = np.array(val_atlas)[
        np.newaxis, ..., np.newaxis]  # is a ndarray, shape is (1, 690, 628, 1)
    # new = Image.fromarray(X) new.size is (628, 690)
    val_vol_size = val_atlas_vol.shape[1:-1]  # (690, 628)
    print(val_vol_size)

    Sm = Image.open(seg_dir)  # is a TiffImageFile _size is (628, 690)
    Sm_ = np.array(Sm)[np.newaxis, ..., np.newaxis]

    val_Sm = Image.open(val_seg_dir)  # is a TiffImageFile _size is (628, 690)
    val_Sm_ = np.array(val_Sm)[np.newaxis, ..., np.newaxis]

    # prepare data files
    # for the CVPR and MICCAI papers, we have data arranged in train/validate/test folders
    # inside each folder is a /vols/ and a /asegs/ folder with the volumes
    # and segmentations. All of our papers use npz formated data.
    # data_dir = D:/voxel/data/01
    # data_dir and val_data_dir are already lists of file paths, so no glob
    # (e.g. glob.glob(os.path.join(data_dir, '*.tif'))) is needed here
    train_vol_names = data_dir
    # random.shuffle(train_vol_names)  # shuffle volume list
    assert len(train_vol_names) > 0, "Could not find any training data"

    val_vol_names = val_data_dir
    # random.shuffle(val_vol_names)  # shuffle volume list
    assert len(val_vol_names) > 0, "Could not find any validation data"

    # UNET filters for voxelmorph-1 and voxelmorph-2,
    # these are architectures presented in CVPR 2018
    nf_enc = [16, 32, 32, 32]
    if model == 'vm1':
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == 'vm2':
        nf_dec = [32, 32, 32, 32, 32, 16, 16]
    else:  # 'vm2double':
        nf_enc = [f * 2 for f in nf_enc]
        nf_dec = [f * 2 for f in [32, 32, 32, 32, 32, 16, 16]]

    assert data_loss in [
        'mse', 'cc', 'ncc'
    ], 'Loss should be one of mse, cc, or ncc; found %s' % data_loss
    if data_loss in ['ncc', 'cc']:
        data_loss = losses.NCC().loss

    if Sf_file is not None:
        Sf = Image.open(Sf_file)
        Sf_ = np.array(Sf)[np.newaxis, ..., np.newaxis]

    if val_Sf_file is not None:
        val_Sf = Image.open(val_Sf_file)
        val_Sf_ = np.array(val_Sf)[np.newaxis, ..., np.newaxis]

    # prepare model folder
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # GPU handling
    # once CUDA_VISIBLE_DEVICES restricts visibility, the chosen card is /gpu:0
    gpu = '/gpu:0'
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))
    #gpu = gpu_id

    # data generator
    nb_gpus = len(gpu_id.split(','))  # 1
    assert np.mod(batch_size, nb_gpus) == 0, \
        'batch_size should be a multiple of the nr. of gpus. ' + \
        'Got batch_size %d, %d gpus' % (batch_size, nb_gpus)

    train_example_gen = datagenerators.example_gen(
        train_vol_names,
        batch_size=batch_size)  # it is a list contain a ndarray
    atlas_vol_bs = np.repeat(
        atlas_vol, batch_size,
        axis=0)  # is a ndarray, if batch_size is 2, shape is (2, 690, 628, 1)
    cvpr2018_gen = datagenerators.cvpr2018_gen(train_example_gen,
                                               atlas_vol_bs,
                                               batch_size=batch_size)

    val_example_gen = datagenerators.example_gen(
        val_vol_names, batch_size=batch_size)  # it is a list contain a ndarray
    val_atlas_vol_bs = np.repeat(
        val_atlas_vol, batch_size,
        axis=0)  # is a ndarray, if batch_size is 2, shape is (2, 690, 628, 1)
    val_cvpr2018_gen = datagenerators.cvpr2018_gen(val_example_gen,
                                                   val_atlas_vol_bs,
                                                   batch_size=batch_size)

    # prepare the model
    with tf.device(gpu):
        # note: this extra Session is never used; set_session above already
        # installed the session Keras will run in
        sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
        # prepare the model
        # in the CVPR layout, the model takes in [image_1, image_2] and outputs [warped_image_1, flow]
        # in the experiments, we use image_2 as atlas

        model = networks.cvpr2018_net(vol_size, nf_enc, nf_dec)

        # load initial weights
        if load_model_file is not None:
            print('loading', load_model_file)
            model.load_weights(load_model_file)

        # save first iteration
        model.save(os.path.join(model_dir, '%02d.h5' % initial_epoch))

        # if auxi_label is not None:
        #     print('yes')
        #     loss_model= [data_loss, losses.Grad('l2').loss, losses.Lseg()._lseg(Sf_) ]    ##########################
        #     loss_weight= [1.0, reg_param, gama_param]
        # else:
        loss_model = [
            data_loss,
            losses.Grad(gama_param, Sf_, Sm_, penalty='l2').loss
        ]  # real gama: reg_param*gama_param
        loss_weight = [1.0, reg_param]

        # reg_param_tensor = tf.constant(5, dtype=tf.float32)
        metrics_2 = losses.Grad(gama_param,
                                val_Sf_,
                                val_Sm_,
                                penalty='l2',
                                flag_vali=True).loss  # reg_param

    # prepare callbacks
    save_file_name = os.path.join(model_dir, '{epoch:02d}.h5')

    # fit generator
    with tf.device(gpu):

        # multi-gpu support
        if nb_gpus > 1:
            save_callback = nrn_gen.ModelCheckpointParallel(save_file_name)
            mg_model = multi_gpu_model(model, gpus=nb_gpus)

        # single-gpu
        else:
            save_callback = ModelCheckpoint(save_file_name)
            mg_model = model

        # compile
        mg_model.compile(optimizer=Adam(lr=lr),
                         loss=loss_model,
                         loss_weights=loss_weight,
                         metrics={'flow': metrics_2})

        # fit
        history = mg_model.fit_generator(cvpr2018_gen,
                                         initial_epoch=initial_epoch,
                                         epochs=nb_epochs,
                                         callbacks=[save_callback],
                                         steps_per_epoch=steps_per_epoch,
                                         validation_data=val_cvpr2018_gen,
                                         validation_steps=1,
                                         verbose=2)

        # plot

        print('model', mg_model.metrics_names)
        print('keys()', history.history.keys())

        # print(metrics.name)

        plt.plot(history.history['loss'])
        # plt.plot(history.history['val_spatial_transformer_1_loss'])
        plt.title('cvpr_auxi_loss')
        plt.ylabel('loss')
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'])
        plt.show()
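
One caveat about the plot above: the legend lists both 'Train' and 'Validation' but only the training curve is drawn. Since fit_generator receives validation_data, history.history should also carry a 'val_loss' series; a small sketch that plots both (exact key names depend on the compiled model, so it checks first):

import matplotlib.pyplot as plt

# assumes `history` comes from the fit_generator call above
plt.plot(history.history['loss'], label='Train')
if 'val_loss' in history.history:
    plt.plot(history.history['val_loss'], label='Validation')
plt.title('cvpr_auxi_loss')
plt.ylabel('loss')
plt.xlabel('Epoch')
plt.legend()
plt.show()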
Code Example #15
def train(data_dir,
          atlas_file,
          model_dir,
          gpu_id,
          lr,
          nb_epochs,
          prior_lambda,
          image_sigma,
          steps_per_epoch,
          batch_size,
          load_model_file,
          bidir,
          initial_epoch=0):
    """
    model training function
    :param data_dir: folder with npz files for each subject.
    :param atlas_file: atlas filename. So far we support npz file with a 'vol' variable
    :param model_dir: model folder to save to
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param nb_epochs: number of training epochs
    :param prior_lambda: the scalar weight on the smoothing (Laplacian) prior in the MICCAI paper
    :param image_sigma: the image noise parameter sigma in the MICCAI paper
    :param steps_per_epoch: number of training batches per epoch (a checkpoint is saved each epoch)
    :param batch_size: Optional, default of 1. can be larger, depends on GPU memory and volume size
    :param load_model_file: optional h5 model file to initialize with
    :param bidir: boolean; whether to use the bidirectional cost function
    """

    # load atlas from provided files. The atlas we used is 160x192x224.
    atlas_vol = np.load(atlas_file)['vol'][np.newaxis, ..., np.newaxis]
    vol_size = atlas_vol.shape[1:-1]
    # prepare data files
    # for the CVPR and MICCAI papers, we have data arranged in train/validate/test folders
    # inside each folder is a /vols/ and a /asegs/ folder with the volumes
    # and segmentations. All of our papers use npz formated data.
    train_vol_names = glob.glob(os.path.join(data_dir, '*.npz'))
    random.shuffle(train_vol_names)  # shuffle volume list
    assert len(train_vol_names) > 0, "Could not find any training data"

    # Diffeomorphic network architecture used in MICCAI 2018 paper
    nf_enc = [16, 32, 32, 32]
    nf_dec = [32, 32, 32, 32, 16, 3]

    # prepare model folder
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # gpu handling
    # once CUDA_VISIBLE_DEVICES restricts visibility, the chosen card is /gpu:0
    gpu = '/gpu:0'
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # prepare the model
    with tf.device(gpu):
        # the MICCAI 2018 model takes in [image_1, image_2] and outputs [warped_image_1, velocity_stats]
        # in these experiments, we use image_2 as atlas
        model = networks.miccai2018_net(vol_size, nf_enc, nf_dec, bidir=bidir)

        # load initial weights
        if load_model_file is not None and load_model_file != "":
            model.load_weights(load_model_file)

        # save first iteration
        model.save(os.path.join(model_dir, '%02d.h5' % initial_epoch))

        # compile
        # note: better to supply vol_shape here than to let tf infer it.
        flow_vol_shape = model.outputs[-1].shape[1:-1]
        loss_class = losses.Miccai2018(image_sigma,
                                       prior_lambda,
                                       flow_vol_shape=flow_vol_shape)
        if bidir:
            model_losses = [
                loss_class.recon_loss, loss_class.recon_loss,
                loss_class.kl_loss
            ]
            loss_weights = [0.5, 0.5, 1]
        else:
            model_losses = [loss_class.recon_loss, loss_class.kl_loss]
            loss_weights = [1, 1]

    # data generator
    nb_gpus = len(gpu_id.split(','))
    assert np.mod(batch_size, nb_gpus) == 0, \
        'batch_size should be a multiple of the nr. of gpus. ' + \
        'Got batch_size %d, %d gpus' % (batch_size, nb_gpus)

    train_example_gen = datagenerators.example_gen(train_vol_names,
                                                   batch_size=batch_size)
    atlas_vol_bs = np.repeat(atlas_vol, batch_size, axis=0)
    miccai2018_gen = datagenerators.miccai2018_gen(train_example_gen,
                                                   atlas_vol_bs,
                                                   batch_size=batch_size,
                                                   bidir=bidir)

    # prepare callbacks
    save_file_name = os.path.join(model_dir, '{epoch:02d}.h5')

    # fit generator
    with tf.device(gpu):

        # multi-gpu support
        if nb_gpus > 1:
            save_callback = nrn_gen.ModelCheckpointParallel(save_file_name)
            mg_model = multi_gpu_model(model, gpus=nb_gpus)

        # single gpu
        else:
            save_callback = ModelCheckpoint(save_file_name)
            mg_model = model

        mg_model.compile(optimizer=Adam(lr=lr),
                         loss=model_losses,
                         loss_weights=loss_weights)
        mg_model.fit_generator(miccai2018_gen,
                               initial_epoch=initial_epoch,
                               epochs=nb_epochs,
                               callbacks=[save_callback],
                               steps_per_epoch=steps_per_epoch,
                               verbose=1)
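
When bidir is set, the model above has three outputs, so the generator must supply three targets: the atlas for the forward warp, the moving image for the backward warp, and zeros for the velocity/flow output. A plausible sketch of what datagenerators.miccai2018_gen yields under that contract (details of the real generator may differ):

import numpy as np

def miccai2018_gen_sketch(example_gen, atlas_vol_bs, batch_size=1, bidir=False):
    # yields ([moving, atlas], targets); the target count depends on bidir
    vol_shape = atlas_vol_bs.shape[1:-1]
    zeros = np.zeros((batch_size, *vol_shape, len(vol_shape)))
    while True:
        X = next(example_gen)[0]
        if bidir:
            yield ([X, atlas_vol_bs], [atlas_vol_bs, X, zeros])
        else:
            yield ([X, atlas_vol_bs], [atlas_vol_bs, zeros])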
Code Example #16
def train(model,
          model_dir,
          gpu_id,
          lr,
          n_iterations,
          reg_param,
          model_save_iter,
          batch_size=1):
    """
    model training function
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param model_dir: the model directory to save to
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param n_iterations: number of training iterations
    :param reg_param: the smoothness/reconstruction tradeoff parameter (lambda in CVPR paper)
    :param model_save_iter: frequency with which to save models
    :param batch_size: Optional, default of 1. can be larger, depends on GPU memory and volume size
    """

    # prepare model folder
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # GPU handling
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # UNET filters for voxelmorph-1 and voxelmorph-2,
    # these are architectures presented in CVPR 2018
    nf_enc = [16, 32, 32, 32]
    if model == 'vm1':
        nf_dec = [32, 32, 32, 32, 8, 8]
    else:
        nf_dec = [32, 32, 32, 32, 32, 16, 16]

    # prepare the model
    # in the CVPR layout, the model takes in [image_1, image_2] and outputs [warped_image_1, flow]
    # in the experiments, we use image_2 as atlas
    model = networks.unet(vol_size, nf_enc, nf_dec)
    model.compile(optimizer=Adam(lr=lr),
                  loss=[losses.cc3D(),
                        losses.gradientLoss('l2')],
                  loss_weights=[1.0, reg_param])

    # if you'd like to initialize the data, you can do it here:
    # model.load_weights(os.path.join(model_dir, '120000.h5'))

    # prepare data for training
    train_example_gen = datagenerators.example_gen(train_vol_names)
    zero_flow = np.zeros([batch_size, *vol_size, 3])

    # train. Note: we use train_on_batch and design our own print function as this has enabled
    # faster development and debugging, but one could also use fit_generator and Keras callbacks.
    for step in range(0, n_iterations):

        # get data
        X = next(train_example_gen)[0]

        # train
        train_loss = model.train_on_batch([X, atlas_vol],
                                          [atlas_vol, zero_flow])
        if not isinstance(train_loss, list):
            train_loss = [train_loss]

        # print the loss.
        print_loss(step, 1, train_loss)

        # save model
        if step % model_save_iter == 0:
            model.save(os.path.join(model_dir, str(step) + '.h5'))
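
Examples #13 and #16 report progress through a print_loss helper that is not included in this excerpt; presumably it prints the step index and each loss component on one line. A minimal sketch (the name, arguments, and exact format are assumptions):

import sys

def print_loss_sketch(step, train_flag, train_loss):
    # one CSV-style line per step: step, train/test flag, loss components
    line = '%d,%d,' % (step, train_flag)
    line += ','.join('%f' % l for l in train_loss)
    print(line)
    sys.stdout.flush()  # keep logs visible when stdout is redirected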
Code Example #17
def train_unsupervised_segmentation(data_dir,
                                    atlas_file,
                                    mapping_file,
                                    model,
                                    model_dir,
                                    gpu_id,
                                    lr,
                                    nb_epochs,
                                    init_stats,
                                    reg_param,
                                    steps_per_epoch,
                                    batch_size,
                                    stat_post_warp,
                                    warp_method,
                                    load_model_file,
                                    initial_epoch=0):
    """
    model training function
    :param data_dir: folder with npz files (coregistered, intensity normalized)
    :param atlas_file: file with probabilistic atlas (coregistered to images)
    :param mapping_file: file with mapping from labels to tissue types
    :param model: registration (voxelmorph) model: vm1, vm2, or vm2double
    :param model_dir: the model directory to save to
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param nb_epochs: number of epochs
    :param init_stats: file with guesses for means and log-variances (vectors init_mu, init_sigma)
    :param reg_param: smoothness/reconstruction tradeoff parameter (lambda in the paper)
    :param steps_per_epoch: number of training batches per epoch (a checkpoint is saved each epoch)
    :param batch_size: default of 1. can be larger, depends on GPU memory and volume size
    :param stat_post_warp: set to 1 to use the warped atlas to estimate Gaussian parameters
    :param warp_method: set to 'WARP' if you want to warp the atlas
    :param load_model_file: optional h5 model file to initialize with
    :param initial_epoch: initial epoch
    """

    # load the probabilistic atlas from the provided file
    # (we used size 160x192x224).
    # Also: group labels into tissue types, if necessary
    if mapping_file is None:
        atlas_vol = np.load(atlas_file)['vol_data'][np.newaxis, ...]
        nb_labels = atlas_vol.shape[-1]

    else:
        atlas_full = np.load(atlas_file)['vol_data'][np.newaxis, ...]

        mapping = np.load(mapping_file)['mapping'].astype('int').flatten()
        assert len(mapping) == atlas_full.shape[-1], \
            'mapping shape %d is inconsistent with atlas shape %d' % (len(mapping), atlas_full.shape[-1])

        nb_labels = int(1 + np.max(mapping))
        atlas_vol = np.zeros([1, *atlas_full.shape[1:-1], nb_labels])
        for j in range(len(mapping)):
            atlas_vol[0, ..., mapping[j]] += atlas_full[0, ..., j]

    vol_size = atlas_vol.shape[1:-1]

    # load guesses for means and variances
    init_mu = np.load(init_stats)['init_mu']
    init_sigma = np.load(init_stats)['init_std']

    # prepare data files
    train_vol_names = glob.glob(os.path.join(data_dir, '*.npz'))
    random.shuffle(train_vol_names)
    assert len(train_vol_names) > 0, "Could not find any training data"

    # UNET filters for voxelmorph-1 and voxelmorph-2,
    # these are architectures presented in CVPR 2018
    nf_enc = [16, 32, 32, 32]
    if model == 'vm1':
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == 'vm2':
        nf_dec = [32, 32, 32, 32, 32, 16, 16]
    else:  # 'vm2double':
        nf_enc = [f * 2 for f in nf_enc]
        nf_dec = [f * 2 for f in [32, 32, 32, 32, 32, 16, 16]]

    # prepare model and log folders
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    log_dir = os.path.join(model_dir, 'logs')
    if not os.path.isdir(log_dir):
        os.mkdir(log_dir)

    # GPU handling
    # once CUDA_VISIBLE_DEVICES restricts visibility, the chosen card is /gpu:0
    gpu = '/gpu:0'
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # prepare the model
    with tf.device(gpu):
        # prepare the model
        model = networks.cvpr2018_net_probatlas(vol_size,
                                                nf_enc,
                                                nf_dec,
                                                nb_labels,
                                                diffeomorphic=True,
                                                full_size=False,
                                                stat_post_warp=stat_post_warp,
                                                warp_method=warp_method,
                                                init_mu=init_mu,
                                                init_sigma=init_sigma)

        # load initial weights
        if load_model_file is not None:
            print('loading', load_model_file)
            model.load_weights(load_model_file)

        # save first iteration
        model.save(os.path.join(model_dir, '%02d.h5' % initial_epoch))

    # data generator
    nb_gpus = len(gpu_id.split(','))
    assert np.mod(batch_size, nb_gpus) == 0, \
        'batch_size should be a multiple of the nr. of gpus. ' + \
        'Got batch_size %d, %d gpus' % (batch_size, nb_gpus)

    train_example_gen = datagenerators.example_gen(train_vol_names,
                                                   batch_size=batch_size)
    atlas_vol_bs = np.repeat(atlas_vol, batch_size, axis=0)
    cvpr2018_gen = datagenerators.cvpr2018_gen(train_example_gen,
                                               atlas_vol_bs,
                                               batch_size=batch_size)

    # prepare callbacks
    save_file_name = os.path.join(model_dir, '{epoch:02d}.h5')

    # fit generator
    with tf.device(gpu):

        # multi-gpu support
        if nb_gpus > 1:
            save_callback = nrn_gen.ModelCheckpointParallel(save_file_name)
            mg_model = multi_gpu_model(model, gpus=nb_gpus)

        # single-gpu
        else:
            save_callback = ModelCheckpoint(save_file_name)
            mg_model = model

        # tensorBoard callback
        tensorboard = TensorBoard(log_dir=log_dir,
                                  histogram_freq=0,
                                  write_graph=True,
                                  write_images=False)

        # compile loss and parameters
        def data_loss(_, yp):
            m = tf.cast(model.inputs[0] > 0, tf.float32)
            return -K.sum(yp * m) / K.sum(m)

        if warp_method != 'WARP':
            reg_param = 0

        # compile
        mg_model.compile(optimizer=Adam(lr=lr),
                         loss=[data_loss, losses.Grad('l2').loss],
                         loss_weights=[1.0, reg_param])

        # fit
        mg_model.fit_generator(cvpr2018_gen,
                               initial_epoch=initial_epoch,
                               epochs=nb_epochs,
                               callbacks=[save_callback, tensorboard],
                               steps_per_epoch=steps_per_epoch,
                               verbose=1)
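
Finally, the label-grouping step at the top of example #17 is easiest to sanity-check on a toy problem: each atlas label channel is added into the tissue-type channel its mapping entry points to, so the total probability mass per voxel is preserved. A self-contained illustration with hypothetical shapes and mapping:

import numpy as np

# toy probabilistic atlas: one 4x4x4 volume with 5 label channels,
# collapsed into 3 tissue types by a hypothetical mapping
atlas_full = np.random.rand(1, 4, 4, 4, 5)
mapping = np.array([0, 1, 1, 2, 2])   # label channel j -> tissue mapping[j]
nb_labels = 1 + mapping.max()

atlas_vol = np.zeros((*atlas_full.shape[:-1], nb_labels))
for tissue in range(nb_labels):
    # sum all label channels assigned to this tissue type
    atlas_vol[..., tissue] = atlas_full[..., mapping == tissue].sum(axis=-1)

# grouping preserves the per-voxel probability mass
assert np.allclose(atlas_vol.sum(-1), atlas_full.sum(-1))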