def train(model_name, gpu_id):
    """
    Train a U-Net registration model and periodically save it.

    Besides its two arguments, this function reads several names that are not
    defined inside it and therefore have to exist at module scope:
    ``vol_size``, ``nf_enc``, ``nf_dec``, ``lr``, ``reg_param``,
    ``train_vol_names``, ``atlas_vol``, ``n_iterations`` and
    ``model_save_iter``.

    :param model_name: name of the sub-folder of ``../models/`` that receives
        the saved ``<step>.h5`` files
    :param gpu_id: integer specifying the gpu to use
    """
    model_dir = '../models/' + model_name
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # GPU handling
    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    with tf.device(gpu):
        model = networks.unet(vol_size, nf_enc, nf_dec)
        model.compile(optimizer=Adam(lr=lr),
                      loss=[losses.cc3D(), losses.gradientLoss('l2')],
                      loss_weights=[1.0, reg_param])

    train_example_gen = datagenerators.example_gen(train_vol_names)
    # all-zero target handed to the second (gradientLoss) output
    zero_flow = np.zeros((1, vol_size[0], vol_size[1], vol_size[2], 3))

    # range()/next() work on Python 2 and 3 alike; the previous
    # xrange()/.next() spelling only exists on Python 2.
    for step in range(0, n_iterations):
        X = next(train_example_gen)[0]
        train_loss = model.train_on_batch([X, atlas_vol],
                                          [atlas_vol, zero_flow])
        # train_on_batch may return a bare scalar; printLoss gets a list
        if not isinstance(train_loss, list):
            train_loss = [train_loss]
        printLoss(step, 1, train_loss)

        # note: step 0 also satisfies this test, so the first save happens
        # right after the first batch
        if step % model_save_iter == 0:
            model.save(model_dir + '/' + str(step) + '.h5')
def train(model, gpu_id, lr, n_iterations, reg_param, model_save_iter, load_iter):
    """
    Train a U-Net against atlases drawn at random, optionally resuming.

    Reads ``vol_size``, ``train_vol_names``, ``atlas_list`` and ``list_num``
    from module scope (they are not defined in this function).

    :param model: 'vm1' selects the short decoder; any other value selects
        the longer one
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param n_iterations: number of training iterations
    :param reg_param: weight of the gradient ('l2') loss term
    :param model_save_iter: a model file is written every this many steps
    :param load_iter: when non-zero, ``<load_iter>.h5`` is loaded from the
        model folder before training; it is also added to the step number in
        the names of the files written
    """
    model_dir = '/home/ys895/MAS3_Models'
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # GPU handling
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    session_config.allow_soft_placement = True
    set_session(tf.Session(config=session_config))

    # UNET filters
    encoder_filters = [16, 32, 32, 32]
    decoder_filters = ([32, 32, 32, 32, 8, 8, 3] if model == 'vm1'
                       else [32, 32, 32, 32, 32, 16, 16, 3])

    with tf.device('/gpu:' + str(gpu_id)):
        net = networks.unet(vol_size, encoder_filters, decoder_filters)
        if load_iter != 0:
            # resume from a file previously written into the model folder
            net.load_weights(model_dir + '/' + str(load_iter) + '.h5')
        net.compile(optimizer=Adam(lr=lr),
                    loss=[losses.cc3D(), losses.gradientLoss('l2')],
                    loss_weights=[1.0, reg_param])

    batches = datagenerators.example_gen(train_vol_names)
    zero_flow = np.zeros((1, vol_size[0], vol_size[1], vol_size[2], 3))

    # steps are numbered from 1 so that saved file names continue on from
    # load_iter
    for step in range(1, n_iterations + 1):
        # pick one atlas at random for this step
        atlas_vol = atlas_list[random.randint(0, list_num - 1)]

        X = next(batches)[0]
        step_loss = net.train_on_batch([atlas_vol, X], [X, zero_flow])
        if not isinstance(step_loss, list):
            step_loss = [step_loss]
        printLoss(step, 1, step_loss)

        if step % model_save_iter == 0:
            net.save(model_dir + '/' + str(load_iter + step) + '.h5')
def test(model_name, iter_num, gpu_id, vol_size=(160, 192, 224), nf_enc=[16, 32, 32, 32], nf_dec=[32, 32, 32, 32]):
    """
    Evaluate a saved autoencoder on one validation sample.

    Loads ``../models/<model_name>/<iter_num>.h5`` into the full model built
    by ``networks.autoencoder``, runs one sample from the validation
    generator through it, prints the mean squared reconstruction error and
    passes the reconstruction to ``slices``.

    ``val_vol_names`` and ``slices`` are read from module scope.

    The filter lists have to be given here because only the weights are
    loaded: per the original author's note, ``load_model`` does not work
    with the custom loss.

    :param model_name: sub-folder of ``../models/`` holding the weights
    :param iter_num: iteration number used as the weights file name
    :param gpu_id: integer specifying the gpu to use
    :param vol_size: volume size handed to ``networks.autoencoder``
    :param nf_enc: encoder filters (the default list is only read here)
    :param nf_dec: decoder filters (the default list is only read here)
    """
    # Local import keeps this block self-contained; numpy is needed for the
    # loss computation below.
    import numpy as np

    gpu = '/gpu:' + str(gpu_id)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # load weights of model
    with tf.device(gpu):
        full_model, train_model = networks.autoencoder(vol_size, nf_enc, nf_dec)
        full_model.load_weights('../models/' + model_name + '/' + str(iter_num) + '.h5')

    val_example_gen = datagenerators.example_gen(val_vol_names)

    total_loss = 0
    for step in range(1):
        # get data
        X = next(val_example_gen)[0]

        # get output
        output, enc = full_model.predict([X])

        # The loss was previously built with tf.reduce_mean/tf.square, which
        # under a tf.Session-based setup yields a symbolic Tensor: print()
        # showed the tensor description rather than a number, and every
        # step added new ops to the graph. predict() output is subtracted
        # from X directly, so plain numpy gives the actual value.
        loss = float(np.mean(np.square(output - X)))

        # print the loss.
        print(step, 0, loss)
        total_loss += loss

        slices(output[0])

    print(total_loss)
def test(gpu, ref_dir, mov_dir, model, init_model_file):
    """
    model testing function: register moving images to reference images with
    a trained network and print a dice score for each pair

    :param gpu: integer (or string) specifying the gpu to use
    :param ref_dir: folder with the reference (fixed) ``.npy`` files
    :param mov_dir: folder with the moving ``.npy`` files
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param init_model_file: file with the saved state dict to load
    :raises ValueError: if ``model`` is neither 'vm1' nor 'vm2'
    """
    # os.environ only accepts strings; str() lets callers pass the integer id
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
    device = "cuda"

    # Prepare the vm1 or vm2 model and send to device
    nf_enc = [16, 32, 32, 32]
    if model == "vm1":
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == "vm2":
        nf_dec = [32, 32, 32, 32, 32, 16, 16]
    else:
        # previously fell through and failed later on the unbound nf_dec
        raise ValueError("Not yet implemented!")

    # Set up model
    vol_size = [2912, 2912]
    model = cvpr2018_net(vol_size, nf_enc, nf_dec)
    model.to(device)
    # map_location keeps the tensors on the storage they are deserialized to
    # (CPU); load_state_dict then copies them into the model on `device`
    model.load_state_dict(
        torch.load(init_model_file,
                   map_location=lambda storage, loc: storage))

    # collect the input files
    ref_vol_names = glob.glob(os.path.join(ref_dir, '*.npy'))
    mov_vol_names = glob.glob(os.path.join(mov_dir, '*.npy'))

    nums = len(ref_vol_names)
    # inference only: no autograd graph is needed
    with torch.no_grad():
        for k in range(0, nums):
            # NOTE(review): example_gen is called afresh on every iteration
            # and its result is unpacked into two arrays -- confirm against
            # datagenerators that this yields a different pair each time.
            refs, movs = datagenerators.example_gen(ref_vol_names,
                                                    mov_vol_names,
                                                    batch_size=1)

            # move the last axis to position 1 (0, 3, 1, 2)
            input_fixed = torch.from_numpy(refs).to(device).float()
            input_fixed = input_fixed.permute(0, 3, 1, 2)

            input_moving = torch.from_numpy(movs).to(device).float()
            input_moving = input_moving.permute(0, 3, 1, 2)

            warp, flow = model(input_moving, input_fixed)

            # Displacement field as a SimpleITK image. It is not written
            # anywhere at the moment; pass it to sitk.WriteImage to export
            # it for visualisation.
            flow_save = sitk.GetImageFromArray(flow.cpu().detach().numpy())

            dice_score = metrics.dice_score(warp, input_fixed)
            print('相似性度量dice:', dice_score)
def train(model, save_name, gpu_id, lr, n_iterations, reg_param, model_save_iter):
    """
    Train a U-Net against a single atlas and periodically save the model.

    Reads ``vol_size``, ``train_vol_names`` and ``atlas_vol`` from module
    scope (they are not defined in this function).

    :param model: 'vm1' selects the short decoder; any other value selects
        the longer one
    :param save_name: sub-folder of ``../models/`` to save into
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param n_iterations: number of training iterations
    :param reg_param: weight of the gradient ('l2') loss term
    :param model_save_iter: a model file is written every this many steps
    """
    out_dir = '../models/' + save_name
    if not os.path.isdir(out_dir):
        os.mkdir(out_dir)

    # GPU handling
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    session_config.allow_soft_placement = True
    set_session(tf.Session(config=session_config))

    # UNET filters
    encoder_filters = [16, 32, 32, 32]
    decoder_filters = ([32, 32, 32, 32, 8, 8, 3] if model == 'vm1'
                       else [32, 32, 32, 32, 32, 16, 16, 3])

    with tf.device('/gpu:' + str(gpu_id)):
        net = networks.unet(vol_size, encoder_filters, decoder_filters)
        net.compile(optimizer=Adam(lr=lr),
                    loss=[losses.cc3D(), losses.gradientLoss('l2')],
                    loss_weights=[1.0, reg_param])

    batches = datagenerators.example_gen(train_vol_names)
    zero_flow = np.zeros((1, vol_size[0], vol_size[1], vol_size[2], 3))

    for step in range(n_iterations):
        X = next(batches)[0]
        step_loss = net.train_on_batch([X, atlas_vol],
                                       [atlas_vol, zero_flow])
        # normalise a bare scalar into a list for printLoss
        if not isinstance(step_loss, list):
            step_loss = [step_loss]
        printLoss(step, 1, step_loss)

        # step 0 passes this test too, so a file is written after batch one
        if step % model_save_iter == 0:
            net.save(out_dir + '/' + str(step) + '.h5')
def train(data_dir, fixed_image, model_dir, device, lr, nb_epochs, AAN_param,
          steps_per_epoch, batch_size, load_model_file, initial_epoch,
          DLR_model):
    """
    model training function for the ``networks.AAN`` model

    :param data_dir: folder with the training ``.npz`` files
    :param fixed_image: npz file with a 'vol' variable used as the fixed
        volume; the value './' disables it and selects subject-to-subject
        batches (``datagenerators.gen_s2s``) instead
    :param model_dir: model folder to save to
    :param device: device description; see the "device handling" section
        below for how it is mapped to a TensorFlow device string
    :param lr: learning rate
    :param nb_epochs: number of epochs
    :param AAN_param: weight of the third loss term
        (``Grad('l1').loss_with_boundary``)
    :param steps_per_epoch: batches per epoch (passed to fit_generator)
    :param batch_size: batch size
    :param load_model_file: weights file to initialize with; './' means none
    :param initial_epoch: epoch to start counting from; also used as the
        file name of the model saved before training starts
    :param DLR_model: 'VM' or 'FAIM'; selects the regularisation loss and is
        forwarded to ``networks.AAN``
    """
    # prepare data files
    # inside the folder are npz files with the 'vol' and 'label'.
    train_vol_names = glob.glob(os.path.join(data_dir, '*.npz'))
    random.shuffle(train_vol_names)  # shuffle volume list
    assert len(train_vol_names) > 0, "Could not find any training data"

    vol_size = [144, 192, 160]

    # load atlas from provided files, if atlas-based registration
    use_fixed = fixed_image != './'
    if use_fixed:
        fixed_vol = np.load(fixed_image)['vol'][np.newaxis, ..., np.newaxis]

    def FAIM_loss(y_true, y_pred):
        return losses.Grad('l2').loss(
            y_true, y_pred) + 1e-5 * losses.NJ_loss(y_true, y_pred)

    # the message used to reference the undefined name LBR_model, so a bad
    # value surfaced as a NameError instead of this message
    assert DLR_model in [
        'VM', 'FAIM'
    ], 'DLR_model should be one of VM or FAIM, found %s' % DLR_model
    if DLR_model == 'FAIM':
        reg_loss = FAIM_loss
    else:
        reg_loss = losses.Grad('l2').loss

    # prepare model folder
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # device handling: any string containing 'gpu' is treated as a GPU
    # request, with '0' / '1' selecting the device. A GPU string containing
    # neither digit is passed to tf.device unchanged.
    if 'gpu' in device:
        if '0' in device:
            device = '/gpu:0'
        if '1' in device:
            device = '/gpu:1'
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        set_session(tf.Session(config=config))
    else:
        device = '/cpu:0'

    # prepare the model
    with tf.device(device):
        model = networks.AAN(vol_size, DLR_model)

        # load initial weights
        if load_model_file != './':
            print('loading', load_model_file)
            model.load_weights(load_model_file)

        # save first iteration
        model.save(os.path.join(model_dir, '%02d.h5' % initial_epoch))

    # data generator
    train_example_gen = datagenerators.example_gen(train_vol_names,
                                                   batch_size=batch_size,
                                                   return_boundary=True)
    if use_fixed:
        # repeat the fixed volume along the batch axis
        fixed_vol_bs = np.repeat(fixed_vol, batch_size, axis=0)
        # was `fixing_vol_bs`, an undefined name
        data_gen = datagenerators.gen_atlas(train_example_gen,
                                            fixed_vol_bs,
                                            batch_size=batch_size)
    else:
        data_gen = datagenerators.gen_s2s(train_example_gen,
                                          batch_size=batch_size)

    # prepare callbacks
    save_file_name = os.path.join(model_dir, '{epoch:02d}.h5')

    # fit generator
    with tf.device(device):
        save_callback = ModelCheckpoint(save_file_name)

        # compile
        model.compile(
            optimizer=Adam(lr=lr),
            loss=['mse', reg_loss, losses.Grad('l1').loss_with_boundary],
            loss_weights=[1.0, 0.01, AAN_param])

        # fit
        model.fit_generator(data_gen,
                            initial_epoch=initial_epoch,
                            epochs=nb_epochs,
                            callbacks=[save_callback],
                            steps_per_epoch=steps_per_epoch,
                            verbose=1)
def train(model, pretrained_path, model_name, gpu_id, lr, n_iterations,
          use_mi, gamma, num_bins, patch_size, max_clip, reg_param,
          model_save_iter, local_mi, sigma_ratio, batch_size=1):
    """
    model training function (mutual information + segmentation losses)

    Reads ``atlas_vol``, ``seg``, ``vol_size``, ``train_vol_names`` and
    ``train_seg_dir`` from module scope (they are not defined here).

    :param model: either vm1 or vm2 (based on CVPR 2018 paper); any value
        other than 'vm1' selects the vm2 decoder
    :param pretrained_path: weights file to initialize with; None or ''
        skips loading
    :param model_name: sub-folder of ``../models/`` to save to
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param n_iterations: number of training iterations
    :param use_mi: when truthy the mutual information loss gets weight 1,
        otherwise weight 0
    :param gamma: weight of the sparse categorical cross-entropy loss
    :param num_bins: number of intensity bin centers spread over
        [0, max_clip]
    :param patch_size: forwarded to ``losses.mutualInformation``
    :param max_clip: the atlas is rescaled so that its maximum equals this
        value; also the upper end of the bin range
    :param reg_param: the smoothness/reconstruction tradeoff parameter
        (lambda in CVPR paper)
    :param model_save_iter: frequency with which to save models
    :param local_mi: forwarded to ``losses.mutualInformation``
    :param sigma_ratio: forwarded to ``losses.mutualInformation``
    :param batch_size: Optional, default of 1. Only sizes the all-zero flow
        target here.
        NOTE(review): the data generator and the atlas arrays are created
        without it -- confirm values other than 1 are supported.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    restrict_GPU_tf(str(gpu_id))
    restrict_GPU_keras(str(gpu_id))

    train_labels = sio.loadmat('../data/labels.mat')['labels'][0]

    # rescale the atlas so that its maximum equals max_clip
    normalized_atlas_vol = atlas_vol / np.max(atlas_vol) * max_clip

    atlas_seg = datagenerators.split_seg_into_channels(seg, train_labels)
    atlas_seg = datagenerators.downsample(atlas_seg)

    model_dir = "../models/" + model_name
    # prepare model folder
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # GPU handling
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # UNET filters for voxelmorph-1 and voxelmorph-2,
    # these are architectures presented in CVPR 2018
    nf_enc = [16, 32, 32, 32]
    if model == 'vm1':
        nf_dec = [32, 32, 32, 32, 8, 8]
    else:
        nf_dec = [32, 32, 32, 32, 32, 16, 16]

    # 2 * num_bins + 1 evenly spaced points over [0, max_clip]; taking every
    # second one starting at index 1 leaves num_bins values
    bin_centers = np.linspace(0, max_clip, num_bins * 2 + 1)[1::2]
    loss_function = losses.mutualInformation(bin_centers,
                                             max_clip=max_clip,
                                             local_mi=local_mi,
                                             patch_size=patch_size,
                                             sigma_ratio=sigma_ratio)

    # prepare the model
    # in the CVPR layout, the model takes in [image_1, image_2] and outputs
    # [warped_image_1, flow]; in the experiments, we use image_2 as atlas
    model = networks.cvpr2018_net(vol_size, nf_enc, nf_dec,
                                  use_seg=True, n_seg=len(train_labels))
    model.compile(optimizer=Adam(lr=lr),
                  loss=[
                      loss_function,
                      losses.gradientLoss('l2'),
                      sparse_categorical_crossentropy
                  ],
                  loss_weights=[1 if use_mi else 0, reg_param, gamma])

    # optional warm start (skipped for None and for the empty string)
    if pretrained_path:
        model.load_weights(pretrained_path)

    # prepare data for training
    train_example_gen = datagenerators.example_gen(train_vol_names,
                                                   return_segs=True,
                                                   seg_dir=train_seg_dir)
    zero_flow = np.zeros([batch_size, *vol_size, 3])

    # train. Note: we use train_on_batch and design our own print function
    # as this has enabled faster development and debugging, but one could
    # also use fit_generator and Keras callbacks.
    for step in range(0, n_iterations):
        # get data: X[0] is the volume batch, X[1] its segmentation
        X = next(train_example_gen)
        X_seg = datagenerators.split_seg_into_channels(X[1], train_labels)
        X_seg = datagenerators.downsample(X_seg)

        # train
        train_loss = model.train_on_batch(
            [X[0], normalized_atlas_vol, X_seg],
            [normalized_atlas_vol, zero_flow, atlas_seg])
        if not isinstance(train_loss, list):
            train_loss = [train_loss]

        # print the loss.
        print_loss(step, 1, train_loss)

        # save model
        if step % model_save_iter == 0:
            model.save(os.path.join(model_dir, str(step) + '.h5'))
def train(data_dir, atlas_file, model_dir, model, gpu_id, lr, nb_epochs,
          prior_lambda, image_sigma, mean_lambda, steps_per_epoch, batch_size,
          load_model_file, bidir, atlas_wt, bias_mult, smooth_pen_layer,
          data_loss, reg_param, ncc_win, initial_epoch=0):
    """
    model training function (network with a learnable atlas layer)

    :param data_dir: glob pattern selecting the training volume files (it
        is handed to ``glob.glob`` unchanged); paths containing 'ADNI' are
        dropped
    :param atlas_file: atlas filename. So far we support npz file with a
        'vol' variable. When None or empty, the initial atlas is the
        average of 100 volumes drawn from the training generator
    :param model_dir: model folder to save to
    :param model: 'm1double' doubles every filter count; every other value
        uses the base filter counts
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param nb_epochs: number of epochs
    :param prior_lambda: the prior_lambda, the scalar in front of the
        smoothing laplacian, in MICCAI paper
    :param image_sigma: the image sigma in MICCAI paper
    :param mean_lambda: scale of the mean-layer loss
        (``mean_lambda * mean(y_pred ** 2)``)
    :param steps_per_epoch: batches per epoch (passed to fit_generator)
    :param batch_size: batch size; can be larger than 1, depends on GPU
        memory and volume size
    :param load_model_file: optional h5 model file to initialize with
    :param bidir: logical whether to use bidirectional cost function
    :param atlas_wt: with ``bidir``, weight of the first data term; the
        second data term gets ``1 - atlas_wt``
    :param bias_mult: not used by this function
    :param smooth_pen_layer: forwarded to ``networks.img_atlas_diff_model``
    :param data_loss: 'mse', 'cc' or 'ncc'
    :param reg_param: with ``bidir``, weight of the ``Grad('l2')`` loss
    :param ncc_win: window size per axis for the NCC loss
    :param initial_epoch: epoch to start counting from; also names the
        model file saved before training starts
    """
    # prepare data files
    # we have data arranged in train/validate/test folders
    # inside each folder is a /vols/ and a /asegs/ folder with the volumes
    # and segmentations. All of our papers use npz formatted data.
    train_vol_names = glob.glob(data_dir)
    train_vol_names = [f for f in train_vol_names if 'ADNI' not in f]
    random.shuffle(train_vol_names)  # shuffle volume list
    assert len(train_vol_names) > 0, "Could not find any training data"

    # data generator
    train_example_gen = datagenerators.example_gen(train_vol_names,
                                                   batch_size=batch_size)

    # prepare the initial weights for the atlas "layer"
    if atlas_file is None or atlas_file == "":
        nb_atl_creation = 100
        print('creating "atlas" by averaging %d subjects' % nb_atl_creation)
        x_avg = 0
        for _ in range(nb_atl_creation):
            # first volume of each batch, channel 0
            x_avg += next(train_example_gen)[0][0, ..., 0]
        x_avg /= nb_atl_creation

        atlas_vol = x_avg[np.newaxis, ..., np.newaxis]
    else:
        atlas_vol = np.load(atlas_file)['vol'][np.newaxis, ..., np.newaxis]
    vol_size = atlas_vol.shape[1:-1]

    # Diffeomorphic network architecture used in MICCAI 2018 paper
    nf_enc = [16, 32, 32, 32]
    nf_dec = [32, 32, 32, 32, 16, 3]
    if model == 'm1double':
        nf_enc = [f * 2 for f in nf_enc]
        nf_dec = [f * 2 for f in nf_dec]

    # prepare model folder
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    assert data_loss in ['mse', 'cc', 'ncc'], \
        'Loss should be one of mse or cc, found %s' % data_loss
    if data_loss in ['ncc', 'cc']:
        data_loss_fn = losses.NCC(win=[ncc_win] * 3).loss
    else:
        def data_loss_fn(y_t, y_p):
            return K.mean(K.square(y_t - y_p))

    # gpu handling
    gpu = '/gpu:' + str(gpu_id)
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # prepare the model
    with tf.device(gpu):
        # the MICCAI 2018 model takes in [image_1, image_2] and outputs
        # [warped_image_1, velocity_stats]
        # in these experiments, we use image_2 as atlas
        net = networks.img_atlas_diff_model(vol_size, nf_enc, nf_dec,
                                            atl_mult=1, bidir=bidir,
                                            smooth_pen_layer=smooth_pen_layer)

        # compile
        def mean_layer_loss(_, y_pred):
            return mean_lambda * K.mean(K.square(y_pred))

        flow_vol_shape = net.outputs[-2].shape[1:-1]
        loss_class = losses.Miccai2018(image_sigma, prior_lambda,
                                       flow_vol_shape=flow_vol_shape)
        if bidir:
            # the second data term compares its prediction with the output
            # of the 'atlas' layer instead of with the supplied target
            def atlas_data_loss(_, y_p):
                return data_loss_fn(net.get_layer('atlas').output, y_p)

            model_losses = [data_loss_fn,
                            atlas_data_loss,
                            mean_layer_loss,
                            losses.Grad('l2').loss]
            loss_weights = [atlas_wt, 1 - atlas_wt, 1, reg_param]
        else:
            # NOTE(review): three losses here, while the generator below
            # always yields four targets -- confirm the non-bidir model's
            # outputs match.
            model_losses = [loss_class.recon_loss,
                            loss_class.kl_loss,
                            mean_layer_loss]
            loss_weights = [1, 1, 1]
        net.compile(optimizer=Adam(lr=lr),
                    loss=model_losses,
                    loss_weights=loss_weights)

        # set initial weights in model
        net.get_layer('atlas').set_weights([atlas_vol[0, ...]])

        # load initial weights.
        # note this overloads the img_param weights
        if load_model_file is not None and len(load_model_file) > 0:
            net.load_weights(load_model_file, by_name=True)

    # save first iteration
    net.save(os.path.join(model_dir, '%02d.h5' % initial_epoch))

    # atlas_generator specific to this model. Once we're convinced of this,
    # move to datagenerators
    def atl_gen(gen):
        """Wrap `gen` into ([atlas, x], [x, atlas, zeros, zeros]) batches."""
        zero_flow = np.zeros([batch_size, *vol_size, len(vol_size)])
        while True:
            # draw from the generator that was passed in (this used to
            # ignore `gen` and reach for the enclosing train_example_gen)
            x2 = next(gen)[0]
            # TODO: note this is the opposite of train_miccai and it might
            # be confusing.
            yield ([atlas_vol, x2], [x2, atlas_vol, zero_flow, zero_flow])

    atlas_gen = atl_gen(train_example_gen)

    # prepare callbacks
    save_file_name = os.path.join(model_dir, '{epoch:02d}.h5')
    save_callback = ModelCheckpoint(save_file_name)

    # fit generator
    with tf.device(gpu):
        net.fit_generator(atlas_gen,
                          initial_epoch=initial_epoch,
                          epochs=nb_epochs,
                          callbacks=[save_callback],
                          steps_per_epoch=steps_per_epoch,
                          verbose=1)
def train(gpu, data_dir, atlas_file, lr, n_iter, data_loss, model, reg_param,
          batch_size, n_save_iter, model_dir):
    """
    model training function

    :param gpu: integer (or string) specifying the gpu to use
    :param data_dir: folder with npz files for each subject.
    :param atlas_file: atlas filename. So far we support npz file with a
        'vol' variable
    :param lr: learning rate
    :param n_iter: number of training iterations
    :param data_loss: 'ncc' selects ``losses.ncc_loss``; any other value
        selects ``losses.mse_loss``
    :param model: either vm1 or vm2 (based on CVPR 2018 paper)
    :param reg_param: the smoothness/reconstruction tradeoff parameter
        (lambda in CVPR paper)
    :param batch_size: batch size; can be larger than 1, depends on GPU
        memory and volume size
    :param n_save_iter: a checkpoint is written every this many iterations
    :param model_dir: the model directory to save to (created if missing)
    :raises ValueError: if ``model`` is neither 'vm1' nor 'vm2', or if
        ``data_dir`` contains no ``.npz`` file
    """
    # os.environ only accepts strings; str() lets callers pass the integer id
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
    device = "cuda"

    # Load the atlas and add a leading and a trailing axis.
    atlas_vol = np.load(atlas_file)['vol'][np.newaxis, ..., np.newaxis]
    vol_size = atlas_vol.shape[1:-1]

    # Get all the names of the training data
    train_vol_names = glob.glob(os.path.join(data_dir, '*.npz'))
    if not train_vol_names:
        raise ValueError("Could not find any training data")
    random.shuffle(train_vol_names)

    # Prepare the vm1 or vm2 model and send to device
    nf_enc = [16, 32, 32, 32]
    if model == "vm1":
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == "vm2":
        nf_dec = [32, 32, 32, 32, 32, 16, 16]
    else:
        raise ValueError("Not yet implemented!")

    model = cvpr2018_net(vol_size, nf_enc, nf_dec)
    model.to(device)

    # Set optimizer and losses
    opt = Adam(model.parameters(), lr=lr)

    sim_loss_fn = losses.ncc_loss if data_loss == "ncc" else losses.mse_loss
    grad_loss_fn = losses.gradient_loss

    # data generator
    train_example_gen = datagenerators.example_gen(train_vol_names,
                                                   batch_size)

    # set up atlas tensor: repeat along the batch axis, then move the last
    # axis to position 1 (0, 4, 1, 2, 3)
    atlas_vol_bs = np.repeat(atlas_vol, batch_size, axis=0)
    input_fixed = torch.from_numpy(atlas_vol_bs).to(device).float()
    input_fixed = input_fixed.permute(0, 4, 1, 2, 3)

    # The checkpoint folder has to exist before the first torch.save below.
    os.makedirs(model_dir, exist_ok=True)

    # Training loop.
    for i in range(n_iter):

        # Save model checkpoint. This runs before the update of iteration
        # i, so the first file holds the untrained weights and the weights
        # produced by the very last update are not written.
        if i % n_save_iter == 0:
            save_file_name = os.path.join(model_dir, '%d.ckpt' % i)
            torch.save(model.state_dict(), save_file_name)

        # Generate the moving images and convert them to tensors.
        moving_image = next(train_example_gen)[0]
        input_moving = torch.from_numpy(moving_image).to(device).float()
        input_moving = input_moving.permute(0, 4, 1, 2, 3)

        # Run the data through the model to produce warp and flow field
        warp, flow = model(input_moving, input_fixed)

        # Calculate loss
        recon_loss = sim_loss_fn(warp, input_fixed)
        grad_loss = grad_loss_fn(flow)
        loss = recon_loss + reg_param * grad_loss

        print("%d,%f,%f,%f" % (i, loss.item(), recon_loss.item(),
                               grad_loss.item()),
              flush=True)

        # Backwards and optimize
        opt.zero_grad()
        loss.backward()
        opt.step()
def train(src_dir, tgt_dir, model_dir, model_lr_dir, lr, nb_epochs, reg_param, steps_per_epoch, batch_size, load_model_file=None, data_loss='ncc', initial_epoch=0): """ model training function :param data_dir: folder with npz files for each subject. :param atlas_file: atlas filename. So far we support npz file with a 'vol' variable :param model: either vm1 or vm2 (based on CVPR 2018 paper) :param model_dir: the model directory to save to :param lr: learning rate :param n_iterations: number of training iterations :param reg_param: the smoothness/reconstruction tradeoff parameter (lambda in CVPR paper) :param steps_per_epoch: frequency with which to save models :param batch_size: Optional, default of 1. can be larger, depends on GPU memory and volume size :param load_model_file: optional h5 model file to initialize with :param data_loss: data_loss: 'mse' or 'ncc """ # prepare data files # for the CVPR and MICCAI papers, we have data arranged in train/validate/test folders # inside each folder is a /vols/ and a /asegs/ folder with the volumes # and segmentations. All of our papers use npz formated data. 
src_vol_names = glob.glob(os.path.join(src_dir, '*.npz')) tgt_vol_names = glob.glob(os.path.join(tgt_dir, '*.npz')) random.shuffle(src_vol_names) # shuffle volume list random.shuffle(tgt_vol_names) # shuffle volume list assert len(src_vol_names) > 0, "Could not find any training data" assert data_loss in [ 'mse', 'ncc' ], 'Loss should be one of mse or cc, found %s' % data_loss if data_loss == 'ncc': data_loss = losses.NCC().loss # GPU handling config = tf.ConfigProto() config.gpu_options.allow_growth = True config.allow_soft_placement = True # set_session(tf.Session(config=config)) vol_size = (56, 56, 56) # prepare the model src_lr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='input_src_lr') tgt_lr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='input_tgt_lr') srm_lr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='mask_src_lr') attn_lr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='attn_lr') src_mr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='input_src_mr') tgt_mr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='input_tgt_mr') srm_mr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='mask_src_mr') df_lr2mr = tf.placeholder(tf.float32, [None, *vol_size, 3], name='df_lr2mr') attn_mr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='attn_mr') src_hr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='input_src_hr') tgt_hr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='input_tgt_hr') srm_hr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='mask_src_hr') df_mr2hr = tf.placeholder(tf.float32, [None, *vol_size, 3], name='df_mr2hr') attn_hr = tf.placeholder(tf.float32, [None, *vol_size, 1], name='attn_hr') model_lr = networks.net_lr(src_lr, tgt_lr, srm_lr) model_mr = networks.net_mr(src_mr, tgt_mr, srm_mr, df_lr2mr) model_hr = networks.net_hr(src_hr, tgt_hr, srm_hr, df_mr2hr) # the loss functions lr_ncc = data_loss(model_lr[0].outputs, tgt_lr) #lr_grd = 
losses.Grad('l2').loss(model_lr[0].outputs, model_lr[2].outputs) lr_grd = losses.Anti_Folding('l2').loss(model_lr[0].outputs, model_lr[2].outputs) cost_lr = lr_ncc + reg_param * lr_grd # + lr_attn mr_ncc = data_loss(model_mr[0].outputs, tgt_mr) #mr_grd = losses.Grad('l2').loss(model_mr[0].outputs, model_mr[2].outputs) mr_grd = losses.Anti_Folding('l2').loss(model_mr[0].outputs, model_mr[2].outputs) cost_mr = mr_ncc + reg_param * mr_grd hr_ncc = data_loss(model_hr[0].outputs, tgt_hr) #hr_grd = losses.Grad('l2').loss(model_hr[0].outputs, model_hr[2].outputs) hr_grd = losses.Anti_Folding('l2').loss(model_hr[0].outputs, model_hr[2].outputs) cost_hr = hr_ncc + reg_param * hr_grd # the training operations def get_v(name): t_vars = tf.trainable_variables() d_vars = [var for var in t_vars if name in var.name] return d_vars #attn_vars = tl.layers.get_variables_with_name('cbam_1', True, True) attn_vars = get_v('cbam_1') for a_v in attn_vars: print(a_v) train_op_lr = tf.train.AdamOptimizer(lr).minimize(cost_lr) train_op_mr = tf.train.AdamOptimizer(lr).minimize(cost_mr) train_op_hr = tf.train.AdamOptimizer(lr).minimize(cost_hr) # data generator src_example_gen = datagenerators.example_gen(src_vol_names, batch_size=batch_size) tgt_example_gen = datagenerators.example_gen(tgt_vol_names, batch_size=batch_size) data_gen = datagenerators.gen_with_mask(src_example_gen, tgt_example_gen, batch_size=batch_size) variables_to_restore = tf.contrib.framework.get_variables_to_restore( exclude=['net_hr']) saver = tf.train.Saver(variables_to_restore) #saver = tf.train.Saver(max_to_keep=3) # fit generator with tf.Session(config=config) as sess: sess.run(tf.global_variables_initializer()) # load initial weights try: if load_model_file is not None: model_file = tf.train.latest_checkpoint(load_model_file) # saver.restore(sess, model_file) except: print('No files in', load_model_file) saver.save(sess, model_dir + 'dfnet', global_step=0) def resize_df(df, zoom): df1 = nd.interpolation.zoom( df[0, 
:, :, :, 0], zoom=zoom, mode='nearest', order=3) * zoom[ 0] # Cubic: order=3; Bilinear: order=1; Nearest: order=0 df2 = nd.interpolation.zoom( df[0, :, :, :, 1], zoom=zoom, mode='nearest', order=3) * zoom[1] df3 = nd.interpolation.zoom( df[0, :, :, :, 2], zoom=zoom, mode='nearest', order=3) * zoom[2] dfs = np.stack((df1, df2, df3), axis=3) return dfs[np.newaxis, :, :, :] class logPrinter(object): def __init__(self): self.n_batch = 0 self.total_dice = [] self.cost = [] self.ncc = [] self.grd = [] def addLog(self, dice, cost, ncc, grd): self.n_batch += 1 self.dice.append(dice) self.cost.append(cost) self.ncc.append(ncc) self.grd.append(grd) def output(self): dice = np.array(self.dice).mean(axis=0).round(3).tolist() cost = np.array(self.cost).mean() ncc = np.array(self.ncc).mean() grd = np.array(self.grd).mean() return dice, cost, ncc, grd, self.n_batch def clear(self): self.n_batch = 0 self.dice = [] self.cost = [] self.ncc = [] self.grd = [] lr_log = logPrinter() mr_log = logPrinter() hr_log = logPrinter() # train low resolution # load initial weights saver = tf.train.Saver(max_to_keep=1) #if model_lr_dir is not None: # model_lr_dir = tf.train.latest_checkpoint(model_lr_dir) # # print(model_lr_dir) # saver.restore(sess, model_lr_dir) nb_epochs = 20 #20#10 steps_per_epoch = 30 * 29 for epoch in range(nb_epochs): tbar = trange(steps_per_epoch, unit='batch', ncols=100) lr_log.clear() for i in tbar: image, mask = data_gen.__next__() global_X, global_atlas = image global_X_mask, global_atlas_mask = mask global_diff = global_X[0, :, :, :, 0] - global_atlas[0, :, :, :, 0] # low resolution global_X_64 = nd.interpolation.zoom(global_X[0, :, :, :, 0], zoom=(0.25, 0.25, 0.25), mode='nearest') global_A_64 = nd.interpolation.zoom(global_atlas[0, :, :, :, 0], zoom=(0.25, 0.25, 0.25), mode='nearest') global_XM_64 = nd.interpolation.zoom(global_X_mask[0, :, :, :, 0], zoom=(0.25, 0.25, 0.25), mode='nearest', order=0) global_AM_64 = nd.interpolation.zoom( global_atlas_mask[0, :, :, 
:, 0], zoom=(0.25, 0.25, 0.25), mode='nearest', order=0) global_diff_16 = nd.interpolation.zoom(global_diff, zoom=(0.25, 0.25, 0.25), mode='nearest') global_X_64 = global_X_64[np.newaxis, :, :, :, np.newaxis] global_A_64 = global_A_64[np.newaxis, :, :, :, np.newaxis] global_XM_64 = global_XM_64[np.newaxis, :, :, :, np.newaxis] global_AM_64 = global_AM_64[np.newaxis, :, :, :, np.newaxis] global_diff_16 = global_diff_16[np.newaxis, :, :, :, np.newaxis] feed_dict = { src_lr: global_X_64, tgt_lr: global_A_64, srm_lr: global_XM_64, attn_lr: global_diff_16 } err_lr, _ = sess.run([cost_lr, train_op_lr], feed_dict=feed_dict) df_lr, warp_seg, elr_ncc, elr_grad, lr_attn_map, lr_attn_feature = sess.run( [ model_lr[2].outputs, model_lr[1].outputs, lr_ncc, lr_grd, model_lr[3], model_lr[4] ], feed_dict=feed_dict) # print(df_lr.shape) lr_dice, _ = dice(warp_seg[0, :, :, :, 0], global_AM_64[0, :, :, :, 0], labels=[0, 10, 150, 250], nargout=2) lr_log.addLog(lr_dice, err_lr, elr_ncc, elr_grad) lr_out = lr_log.output() tbar.set_description('Epoch %d/%d ### step %i' % (epoch + 1, nb_epochs, i)) tbar.set_postfix(lr_dice=lr_out[0], lr_cost=lr_out[1], lr_ncc=lr_out[2], lr_grd=lr_out[3]) saver.save(sess, model_lr_dir + 'dfnet', global_step=0) # train middle resolution nb_epochs = 1 #1 steps_per_epoch = 30 * 29 for epoch in range(nb_epochs): lr_log.clear() for lr_step in range(steps_per_epoch): image, mask = data_gen.__next__() global_X, global_atlas = image global_X_mask, global_atlas_mask = mask global_diff = global_X[0, :, :, :, 0] - global_atlas[0, :, :, :, 0] # low resolution global_X_64 = nd.interpolation.zoom(global_X[0, :, :, :, 0], zoom=(0.25, 0.25, 0.25), mode='nearest') global_A_64 = nd.interpolation.zoom(global_atlas[0, :, :, :, 0], zoom=(0.25, 0.25, 0.25), mode='nearest') global_XM_64 = nd.interpolation.zoom(global_X_mask[0, :, :, :, 0], zoom=(0.25, 0.25, 0.25), mode='nearest', order=0) global_AM_64 = nd.interpolation.zoom( global_atlas_mask[0, :, :, :, 0], zoom=(0.25, 0.25, 
0.25), mode='nearest', order=0) global_diff_16 = nd.interpolation.zoom(global_diff, zoom=(0.25, 0.25, 0.25), mode='nearest') global_X_64 = global_X_64[np.newaxis, :, :, :, np.newaxis] global_A_64 = global_A_64[np.newaxis, :, :, :, np.newaxis] global_XM_64 = global_XM_64[np.newaxis, :, :, :, np.newaxis] global_AM_64 = global_AM_64[np.newaxis, :, :, :, np.newaxis] global_diff_16 = global_diff_16[np.newaxis, :, :, :, np.newaxis] feed_dict = { src_lr: global_X_64, tgt_lr: global_A_64, srm_lr: global_XM_64, attn_lr: global_diff_16 } err_lr, _ = sess.run([cost_lr, train_op_lr], feed_dict=feed_dict) df_lr, warp_seg, elr_ncc, elr_grad, lr_attn_map, lr_attn_feature = sess.run( [ model_lr[2].outputs, model_lr[1].outputs, lr_ncc, lr_grd, model_lr[3], model_lr[4] ], feed_dict=feed_dict) lr_dice, _ = dice(warp_seg[0, :, :, :, 0], global_AM_64[0, :, :, :, 0], labels=[0, 10, 150, 250], nargout=2) lr_log.addLog(lr_dice, err_lr, elr_ncc, elr_grad) lr_out = lr_log.output() print('\nEpoch %d/%d ### step %i' % (epoch + 1, nb_epochs, lr_out[-1])) print( '[lr] lr_dice={}, lr_cost={:.3f}, lr_ncc={:.3f}, lr_grd={:.3f}' .format(lr_out[0], lr_out[1], lr_out[2], lr_out[3])) # middle part df_lr_res2mr = resize_df(df_lr, zoom=(2, 2, 2)) select_points_lr = patch_selection_attn(lr_attn_map, zoom_scales=[8, 8, 8], kernel=7, mi=10, ma=18) print(select_points_lr) mr_log.clear() for sp in select_points_lr: mov_img_112 = global_X[0, sp[0] - 56:sp[0] + 56, sp[1] - 56:sp[1] + 56, sp[2] - 56:sp[2] + 56, 0] fix_img_112 = global_atlas[0, sp[0] - 56:sp[0] + 56, sp[1] - 56:sp[1] + 56, sp[2] - 56:sp[2] + 56, 0] mov_seg_112 = global_X_mask[0, sp[0] - 56:sp[0] + 56, sp[1] - 56:sp[1] + 56, sp[2] - 56:sp[2] + 56, 0] fix_seg_112 = global_atlas_mask[0, sp[0] - 56:sp[0] + 56, sp[1] - 56:sp[1] + 56, sp[2] - 56:sp[2] + 56, 0] dif_img_112 = global_diff[sp[0] - 56:sp[0] + 56, sp[1] - 56:sp[1] + 56, sp[2] - 56:sp[2] + 56] #print(mov_img_112.shape) if fix_img_112.shape != (112, 112, 112): print(mov_img_112.shape) 
continue fix_112_56 = nd.interpolation.zoom(fix_img_112, zoom=(0.5, 0.5, 0.5), mode='nearest') mov_112_56 = nd.interpolation.zoom(mov_img_112, zoom=(0.5, 0.5, 0.5), mode='nearest') fix_112_56m = nd.interpolation.zoom(fix_seg_112, zoom=(0.5, 0.5, 0.5), mode='nearest', order=0) mov_112_56m = nd.interpolation.zoom(mov_seg_112, zoom=(0.5, 0.5, 0.5), mode='nearest', order=0) dif_112_56 = nd.interpolation.zoom(dif_img_112, zoom=(0.5, 0.5, 0.5), mode='nearest') mid_fix_img = fix_112_56[np.newaxis, :, :, :, np.newaxis] mid_mov_img = mov_112_56[np.newaxis, :, :, :, np.newaxis] mid_fix_seg = fix_112_56m[np.newaxis, :, :, :, np.newaxis] mid_mov_seg = mov_112_56m[np.newaxis, :, :, :, np.newaxis] mid_dif_img = dif_112_56[np.newaxis, :, :, :, np.newaxis] df_mr_feed = df_lr_res2mr[:, sp[0] // 2 - 28:sp[0] // 2 + 28, sp[1] // 2 - 28:sp[1] // 2 + 28, sp[2] // 2 - 28:sp[2] // 2 + 28, :] feed_dict = { src_mr: mid_mov_img, tgt_mr: mid_fix_img, srm_mr: mid_mov_seg, df_lr2mr: df_mr_feed, attn_mr: mid_dif_img } err_mr, _ = sess.run([cost_mr, train_op_mr], feed_dict=feed_dict) df_mr, warp_seg, emr_ncc, emr_grad = sess.run( [ model_mr[2].outputs, model_mr[1].outputs, mr_ncc, mr_grd ], feed_dict=feed_dict) mr_dice, _ = dice(warp_seg[0, :, :, :, 0], mid_fix_seg[0, :, :, :, 0], labels=[0, 10, 150, 250], nargout=2) mr_log.addLog(mr_dice, err_mr, emr_ncc, emr_grad) mr_out = mr_log.output() # print(' Epoch %d/%d ### step %i' % (epoch+1, nb_epochs, mr_out[-1])) print( ' [mr] {}/{} mr_dice={}, mr_cost={:.3f}, mr_ncc={:.3f}, mr_grd={:.3f}' .format(mr_out[-1], len(select_points_lr), mr_out[0], mr_out[1], mr_out[2], mr_out[3])) saver.save(sess, model_dir + 'dfnet', global_step=0) # train high resolution nb_epochs = 1 steps_per_epoch = 300 for epoch in range(nb_epochs): lr_log.clear() for lr_step in range(steps_per_epoch): image, mask = data_gen.__next__() global_X, global_atlas = image global_X_mask, global_atlas_mask = mask global_diff = global_X[0, :, :, :, 0] - global_atlas[0, :, :, :, 0] # low 
resolution global_X_64 = nd.interpolation.zoom(global_X[0, :, :, :, 0], zoom=(0.25, 0.25, 0.25), mode='nearest') global_A_64 = nd.interpolation.zoom(global_atlas[0, :, :, :, 0], zoom=(0.25, 0.25, 0.25), mode='nearest') global_XM_64 = nd.interpolation.zoom(global_X_mask[0, :, :, :, 0], zoom=(0.25, 0.25, 0.25), mode='nearest', order=0) global_AM_64 = nd.interpolation.zoom( global_atlas_mask[0, :, :, :, 0], zoom=(0.25, 0.25, 0.25), mode='nearest', order=0) global_diff_16 = nd.interpolation.zoom(global_diff, zoom=(0.25, 0.25, 0.25), mode='nearest') global_X_64 = global_X_64[np.newaxis, :, :, :, np.newaxis] global_A_64 = global_A_64[np.newaxis, :, :, :, np.newaxis] global_XM_64 = global_XM_64[np.newaxis, :, :, :, np.newaxis] global_AM_64 = global_AM_64[np.newaxis, :, :, :, np.newaxis] global_diff_16 = global_diff_16[np.newaxis, :, :, :, np.newaxis] feed_dict = { src_lr: global_X_64, tgt_lr: global_A_64, srm_lr: global_XM_64, attn_lr: global_diff_16 } err_lr, _ = sess.run([cost_lr, train_op_lr], feed_dict=feed_dict) df_lr, warp_seg, elr_ncc, elr_grad, lr_attn_map, lr_attn_feature = sess.run( [ model_lr[2].outputs, model_lr[1].outputs, lr_ncc, lr_grd, model_lr[3], model_lr[4] ], feed_dict=feed_dict) lr_dice, _ = dice(warp_seg[0, :, :, :, 0], global_AM_64[0, :, :, :, 0], labels=[0, 10, 150, 250], nargout=2) lr_log.addLog(lr_dice, err_lr, elr_ncc, elr_grad) lr_out = lr_log.output() print('\nEpoch %d/%d ### step %i' % (epoch + 1, nb_epochs, lr_out[-1])) print( '[lr] lr_dice={}, lr_cost={:.3f}, lr_ncc={:.3f}, lr_grd={:.3f}' .format(lr_out[0], lr_out[1], lr_out[2], lr_out[3])) # middle part df_lr_res2mr = resize_df(df_lr, zoom=(2, 2, 2)) select_points_lr = patch_selection_attn(lr_attn_map, zoom_scales=[8, 8, 8], kernel=7, mi=10, ma=18) print(select_points_lr) mr_log.clear() for sp in select_points_lr: mov_img_112 = global_X[0, sp[0] - 56:sp[0] + 56, sp[1] - 56:sp[1] + 56, sp[2] - 56:sp[2] + 56, 0] fix_img_112 = global_atlas[0, sp[0] - 56:sp[0] + 56, sp[1] - 56:sp[1] + 56, 
sp[2] - 56:sp[2] + 56, 0] mov_seg_112 = global_X_mask[0, sp[0] - 56:sp[0] + 56, sp[1] - 56:sp[1] + 56, sp[2] - 56:sp[2] + 56, 0] fix_seg_112 = global_atlas_mask[0, sp[0] - 56:sp[0] + 56, sp[1] - 56:sp[1] + 56, sp[2] - 56:sp[2] + 56, 0] dif_img_112 = global_diff[sp[0] - 56:sp[0] + 56, sp[1] - 56:sp[1] + 56, sp[2] - 56:sp[2] + 56] #print(mov_img_112.shape) if fix_img_112.shape != (112, 112, 112): print(mov_img_112.shape) continue fix_112_56 = nd.interpolation.zoom(fix_img_112, zoom=(0.5, 0.5, 0.5), mode='nearest') mov_112_56 = nd.interpolation.zoom(mov_img_112, zoom=(0.5, 0.5, 0.5), mode='nearest') fix_112_56m = nd.interpolation.zoom(fix_seg_112, zoom=(0.5, 0.5, 0.5), mode='nearest', order=0) mov_112_56m = nd.interpolation.zoom(mov_seg_112, zoom=(0.5, 0.5, 0.5), mode='nearest', order=0) dif_112_56 = nd.interpolation.zoom(dif_img_112, zoom=(0.5, 0.5, 0.5), mode='nearest') mid_fix_img = fix_112_56[np.newaxis, :, :, :, np.newaxis] mid_mov_img = mov_112_56[np.newaxis, :, :, :, np.newaxis] mid_fix_seg = fix_112_56m[np.newaxis, :, :, :, np.newaxis] mid_mov_seg = mov_112_56m[np.newaxis, :, :, :, np.newaxis] mid_dif_img = dif_112_56[np.newaxis, :, :, :, np.newaxis] df_mr_feed = df_lr_res2mr[:, sp[0] // 2 - 28:sp[0] // 2 + 28, sp[1] // 2 - 28:sp[1] // 2 + 28, sp[2] // 2 - 28:sp[2] // 2 + 28, :] feed_dict = { src_mr: mid_mov_img, tgt_mr: mid_fix_img, srm_mr: mid_mov_seg, df_lr2mr: df_mr_feed, attn_mr: mid_dif_img } err_mr, _ = sess.run([cost_mr, train_op_mr], feed_dict=feed_dict) df_mr, warp_seg, emr_ncc, emr_grad, mr_attn_map, mr_attn_feature = sess.run( [ model_mr[2].outputs, model_mr[1].outputs, mr_ncc, mr_grd, model_mr[3], model_mr[4] ], feed_dict=feed_dict) mr_dice, _ = dice(warp_seg[0, :, :, :, 0], mid_fix_seg[0, :, :, :, 0], labels=[0, 10, 150, 250], nargout=2) mr_log.addLog(mr_dice, err_mr, emr_ncc, emr_grad) mr_out = mr_log.output() # print(' Epoch %d/%d ### step %i' % (epoch+1, nb_epochs, mr_out[-1])) print( ' [mr] {}/{} mr_dice={}, mr_cost={:.3f}, mr_ncc={:.3f}, 
mr_grd={:.3f}' .format(mr_out[-1], len(select_points_lr), mr_out[0], mr_out[1], mr_out[2], mr_out[3])) # high part df_mr_res2hr = resize_df(df_mr, zoom=(2, 2, 2)) hr_log.clear() select_points_mr = patch_selection_attn( mr_attn_map, zoom_scales=[4, 4, 4], kernel=7, mi=8, ma=20) print(30 * '-') print('High Part') print(select_points_mr) for spm in select_points_mr: fix_img_56 = fix_img_112[spm[0] - 28:spm[0] + 28, spm[1] - 28:spm[1] + 28, spm[2] - 28:spm[2] + 28] mov_img_56 = mov_img_112[spm[0] - 28:spm[0] + 28, spm[1] - 28:spm[1] + 28, spm[2] - 28:spm[2] + 28] fix_seg_56 = fix_seg_112[spm[0] - 28:spm[0] + 28, spm[1] - 28:spm[1] + 28, spm[2] - 28:spm[2] + 28] mov_seg_56 = mov_seg_112[spm[0] - 28:spm[0] + 28, spm[1] - 28:spm[1] + 28, spm[2] - 28:spm[2] + 28] dif_img_56 = dif_img_112[spm[0] - 28:spm[0] + 28, spm[1] - 28:spm[1] + 28, spm[2] - 28:spm[2] + 28] if fix_img_56.shape != (56, 56, 56): continue hig_fix_img = fix_img_56[np.newaxis, :, :, :, np.newaxis] hig_mov_img = mov_img_56[np.newaxis, :, :, :, np.newaxis] hig_fix_seg = fix_seg_56[np.newaxis, :, :, :, np.newaxis] hig_mov_seg = mov_seg_56[np.newaxis, :, :, :, np.newaxis] hig_dif_img = dif_img_56[np.newaxis, :, :, :, np.newaxis] df_hr_feed = df_mr_res2hr[:, spm[0] - 28:spm[0] + 28, spm[1] - 28:spm[1] + 28, spm[2] - 28:spm[2] + 28, :] feed_dict = { src_hr: hig_mov_img, tgt_hr: hig_fix_img, srm_hr: hig_mov_seg, df_mr2hr: df_hr_feed, attn_hr: hig_dif_img } err_hr, _ = sess.run([cost_hr, train_op_hr], feed_dict=feed_dict) df_hr, warp_seg, ehr_ncc, ehr_grad = sess.run( [ model_hr[2].outputs, model_hr[1].outputs, hr_ncc, hr_grd ], feed_dict=feed_dict) hr_dice, _ = dice(warp_seg[0, :, :, :, 0], hig_fix_seg[0, :, :, :, 0], labels=[0, 10, 150, 250], nargout=2) hr_log.addLog(hr_dice, err_hr, ehr_ncc, ehr_grad) hr_out = hr_log.output() # print(' Epoch %d/%d ### step %i' % (epoch+1, nb_epochs, mr_out[-1])) print( ' [hr] {}/{} hr_dice={}, hr_cost={:.3f}, hr_ncc={:.3f}, hr_grd={:.3f}' .format(hr_out[-1], 
len(select_points_mr), hr_out[0], hr_out[1], hr_out[2], hr_out[3])) saver.save(sess, model_dir + 'dfnet', global_step=lr_step)
def train(gpu, data_dir, size, atlas_dir, lr, n_iter, data_loss, model,
          reg_param, batch_size, n_save_iter, model_dir, nr_val_data):
    """
    Train a CVPR-2018 style registration network with PyTorch.

    Every ``n_save_iter`` iterations (including iteration 0) the current
    weights are written to ``<model_dir>/<iteration>.ckpt`` and one
    validation batch is evaluated and logged.

    :param gpu: id of the gpu to use (int or str); exported through
        CUDA_VISIBLE_DEVICES
    :param data_dir: folder that is searched for ``*.nii`` volumes
    :param size: int, edge length used to build the ``[size, size, size]``
        volume shape handed to the network and the data generator
    :param atlas_dir: atlas folder, forwarded to the data generator
    :param lr: learning rate
    :param n_iter: number of training iterations
    :param data_loss: ``'ncc'`` selects ``losses.ncc_loss``; any other value
        selects ``losses.mse_loss``
    :param model: ``'vm1'`` or ``'vm2'`` (based on CVPR 2018 paper)
    :param reg_param: weight of the gradient (smoothness) loss
        (lambda in CVPR paper)
    :param batch_size: batch size forwarded to the training data generator
    :param n_save_iter: number of iterations between checkpoints/validation
    :param model_dir: directory the checkpoints are written to (created if
        it does not exist)
    :param nr_val_data: number of volumes (taken from the front of the glob
        result) that are held out as validation data
    :raises ValueError: if ``model`` is neither ``'vm1'`` nor ``'vm2'``
    """
    # os.environ only accepts strings; the gpu id is documented as an integer.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
    device = "cuda"
    vol_size = np.array([size, size, size])

    # Split the volumes: the first nr_val_data files (in glob order) are kept
    # for validation, the remainder is shuffled and used for training.
    vol_names = glob.glob(os.path.join(data_dir, '*.nii'))
    test_vol_names = vol_names[:nr_val_data]
    print(
        'these volumes are separated from the data and serve as validation data : '
    )
    print(test_vol_names)
    train_vol_names = vol_names[nr_val_data:]
    random.shuffle(train_vol_names)

    # Checkpoints are written with torch.save, which does not create folders.
    os.makedirs(model_dir, exist_ok=True)

    writer = SummaryWriter(get_outputs_path())
    try:
        # Prepare the vm1 or vm2 model and send to device
        nf_enc = [16, 32, 32, 32]
        if model == "vm1":
            nf_dec = [32, 32, 32, 32, 8, 8]
        elif model == "vm2":
            nf_dec = [32, 32, 32, 32, 32, 16, 16]
        else:
            raise ValueError("Not yet implemented!")

        model = cvpr2018_net(vol_size, nf_enc, nf_dec)
        model.to(device)

        # Set optimizer and losses
        opt = Adam(model.parameters(), lr=lr)
        sim_loss_fn = losses.ncc_loss if data_loss == "ncc" else losses.mse_loss
        grad_loss_fn = losses.gradient_loss

        # data generator
        train_example_gen = datagenerators.example_gen(train_vol_names,
                                                       atlas_dir, size,
                                                       batch_size)

        # Training loop.
        for i in range(n_iter):
            # Save model checkpoint and log a validation score
            if i % n_save_iter == 0:
                save_file_name = os.path.join(model_dir, '%d.ckpt' % i)
                torch.save(model.state_dict(), save_file_name)

                # A fresh generator is created for every validation pass; the
                # literal 4 sits in the position where the training call
                # passes batch_size.
                val_example_gen = datagenerators.example_gen(
                    test_vol_names, atlas_dir, size, 4)
                val_data = next(val_example_gen)

                # No gradients are needed for validation; skipping the
                # autograd graph keeps memory free for the training step.
                with torch.no_grad():
                    val_fixed = _to_channels_first(val_data[1], device)
                    val_moving = _to_channels_first(val_data[0], device)
                    val_warp, val_flow = model(val_moving, val_fixed)
                    val_recon_loss = sim_loss_fn(val_warp, val_fixed)
                    val_grad_loss = grad_loss_fn(val_flow)
                    val_loss = val_recon_loss + reg_param * val_grad_loss

                writer.add_scalar('Loss/Test', val_loss, i)
                print('validation')
                print("%d,%f,%f,%f" % (i, val_loss.item(),
                                       val_recon_loss.item(),
                                       val_grad_loss.item()),
                      flush=True)

            # Generate the moving images and convert them to tensors.
            data_for_network = next(train_example_gen)
            input_fixed = _to_channels_first(data_for_network[1], device)
            input_moving = _to_channels_first(data_for_network[0], device)

            # Run the data through the model to produce warp and flow field
            warp, flow = model(input_moving, input_fixed)
            print("warp_and_flow_field")
            print(warp.size())
            print(flow.size())

            # Calculate loss
            recon_loss = sim_loss_fn(warp, input_fixed)
            grad_loss = grad_loss_fn(flow)
            loss = recon_loss + reg_param * grad_loss

            writer.add_scalar('Loss/Train', loss, i)
            print("%d,%f,%f,%f" %
                  (i, loss.item(), recon_loss.item(), grad_loss.item()),
                  flush=True)

            # Backwards and optimize
            opt.zero_grad()
            loss.backward()
            opt.step()
    finally:
        # Flush pending tensorboard events even if training is interrupted.
        writer.close()


def _to_channels_first(volume, device):
    """Turn a 5-D numpy batch into a float tensor on *device* with its last axis moved to position 1."""
    return torch.from_numpy(volume).to(device).float().permute(0, 4, 1, 2, 3)
def train(data_dir, atlas_file, model, model_name, gpu_id, lr, nb_epochs,
          reg_param, steps_per_epoch, batch_size, load_model_file, data_loss,
          initial_epoch=0):
    """
    Train a CVPR-2018 registration network against a single atlas with Keras.

    Models are written to ``../models/<model_name>/<epoch>.h5``: once before
    training (using ``initial_epoch``) and after every epoch through a
    ``ModelCheckpoint`` callback.

    :param data_dir: folder that is searched for ``*.npz`` training volumes
    :param atlas_file: atlas filename; it is opened with ``nib.load``
    :param model: ``'vm1'`` or ``'vm2'`` (based on CVPR 2018 paper); any other
        value selects the doubled-filter ``vm2`` layout ('vm2double')
    :param model_name: name of the sub-folder of ``../models/`` to save to
    :param gpu_id: id of the gpu to use, exported via CUDA_VISIBLE_DEVICES
    :param lr: learning rate
    :param nb_epochs: epoch index at which training stops (Keras ``epochs``)
    :param reg_param: the smoothness/reconstruction tradeoff parameter
        (lambda in CVPR paper)
    :param steps_per_epoch: batches per epoch, i.e. the frequency with which
        models are saved
    :param batch_size: batch size; the atlas is repeated this many times
    :param load_model_file: optional h5 model file to initialize with
        (``None`` or ``''`` to skip)
    :param data_loss: ``'mse'``, ``'cc'`` or ``'ncc'``
    :param initial_epoch: epoch index to resume from (default 0)
    """
    # Load the atlas. Image.get_data() has been removed from recent nibabel
    # releases; np.asanyarray(img.dataobj) returns the same on-disk-dtype array.
    atlas_img = nib.load(atlas_file)
    atlas_vol = np.asanyarray(atlas_img.dataobj)[np.newaxis, ..., np.newaxis]
    vol_size = atlas_vol.shape[1:-1]

    # prepare data files
    train_vol_names = glob.glob(os.path.join(data_dir, '*.npz'))
    random.shuffle(train_vol_names)  # shuffle volume list
    assert len(train_vol_names) > 0, "Could not find any training data"

    # UNET filters for voxelmorph-1 and voxelmorph-2,
    # these are architectures presented in CVPR 2018
    nf_enc = [16, 32, 32, 32]
    if model == 'vm1':
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == 'vm2':
        nf_dec = [32, 32, 32, 32, 32, 16, 16]
    else:  # 'vm2double':
        nf_enc = [f * 2 for f in nf_enc]
        nf_dec = [f * 2 for f in [32, 32, 32, 32, 32, 16, 16]]

    assert data_loss in [
        'mse', 'cc', 'ncc'
    ], 'Loss should be one of mse or cc, found %s' % data_loss
    if data_loss in ['ncc', 'cc']:
        data_loss = losses.NCC().loss

    # prepare model folder; makedirs also creates a missing '../models' parent
    model_dir = "../models/" + model_name
    os.makedirs(model_dir, exist_ok=True)

    # GPU handling. Once CUDA_VISIBLE_DEVICES is restricted to one card, that
    # card is the only one the process sees, so it is addressed as device 0
    # regardless of its physical index.
    gpu = '/gpu:0'
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # prepare the model
    with tf.device(gpu):
        # in the CVPR layout, the model takes in [image_1, image_2] and
        # outputs [warped_image_1, flow]; image_2 is the atlas here
        model = networks.cvpr2018_net(vol_size, nf_enc, nf_dec)

        # load initial weights
        if load_model_file is not None and load_model_file != '':
            print('loading', load_model_file)
            model.load_weights(load_model_file)

        # save first iteration
        model.save(os.path.join(model_dir, '%02d.h5' % initial_epoch))

    # data generator
    train_example_gen = datagenerators.example_gen(train_vol_names,
                                                   batch_size=batch_size)
    atlas_vol_bs = np.repeat(atlas_vol, batch_size, axis=0)
    cvpr2018_gen = datagenerators.cvpr2018_gen(train_example_gen,
                                               atlas_vol_bs,
                                               batch_size=batch_size)

    # prepare callbacks
    save_file_name = os.path.join(model_dir, '{epoch:02d}.h5')
    save_callback = ModelCheckpoint(save_file_name)

    # This entry point takes a single gpu id, so only the single-gpu path is
    # kept (the former multi-gpu branch could never be reached).
    with tf.device(gpu):
        model.compile(optimizer=Adam(lr=lr),
                      loss=[data_loss, losses.Grad('l2').loss],
                      loss_weights=[1.0, reg_param])

        model.fit_generator(cvpr2018_gen,
                            initial_epoch=initial_epoch,
                            epochs=nb_epochs,
                            callbacks=[save_callback],
                            steps_per_epoch=steps_per_epoch,
                            verbose=1)
def train(model_dir, gpu_id, lr, n_iterations, alpha, image_sigma,
          model_save_iter, batch_size=1):
    """
    Train the MICCAI-2018 (diffeomorphic) network with ``train_on_batch``.

    The function reads the module-level names ``vol_size``,
    ``train_vol_names`` and ``atlas_vol``. The untrained model is saved as
    ``0.h5``; afterwards steps ``1 .. n_iterations - 1`` are run and the model
    is saved after each of the first nine steps and whenever
    ``step % model_save_iter == 0``.

    :param model_dir: model folder to save to (created if missing)
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param n_iterations: upper bound (exclusive) of the training step counter
    :param alpha: the alpha, the scalar in front of the smoothing laplacian,
        in MICCAI paper
    :param image_sigma: the image sigma in MICCAI paper
    :param model_save_iter: frequency with which to save models
    :param batch_size: currently unused; the zero target built below has a
        leading dimension of 1
    """
    # prepare model folder (makedirs also handles nested, not-yet-existing paths)
    os.makedirs(model_dir, exist_ok=True)
    print(model_dir)

    # gpu handling. After CUDA_VISIBLE_DEVICES is restricted to one card, that
    # card is the only device the process sees and is addressed as device 0,
    # whatever its physical index is.
    gpu = '/gpu:0'
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # Diffeomorphic network architecture used in MICCAI 2018 paper
    nf_enc = [16, 32, 32, 32]
    nf_dec = [32, 32, 32, 32, 16, 3]

    # prepare the model
    # the model takes in [image_1, image_2] and outputs
    # [warped_image_1, velocity_stats]; image_2 is the atlas here
    with tf.device(gpu):
        # miccai 2018 used xy indexing.
        model = networks.miccai2018_net(vol_size, nf_enc, nf_dec,
                                        use_miccai_int=True, indexing='xy')

        # compile
        model_losses = [losses.kl_l2loss(image_sigma), losses.kl_loss(alpha)]
        model.compile(optimizer=Adam(lr=lr), loss=model_losses)

        # save first iteration
        model.save(os.path.join(model_dir, str(0) + '.h5'))

    train_example_gen = datagenerators.example_gen(train_vol_names)
    # Placeholder target passed for the second model output.
    zeros = np.zeros((1, *vol_size, 3))

    # train. Note: we use train_on_batch and design our own print function as
    # this has enabled faster development and debugging, but one could also
    # use fit_generator and Keras callbacks.
    for step in range(1, n_iterations):
        # get data
        X = next(train_example_gen)[0]

        # train
        with tf.device(gpu):
            train_loss = model.train_on_batch([X, atlas_vol],
                                              [atlas_vol, zeros])
        if not isinstance(train_loss, list):
            train_loss = [train_loss]

        # print
        print_loss(step, 0, train_loss)

        # save model
        with tf.device(gpu):
            if (step % model_save_iter == 0) or step < 10:
                model.save(os.path.join(model_dir, str(step) + '.h5'))
def train(
        data_dir,
        val_data_dir,
        atlas_file,
        val_atlas_file,
        model,
        model_dir,
        gpu_id,
        lr,
        nb_epochs,
        reg_param,
        gama_param,
        steps_per_epoch,
        batch_size,
        load_model_file,
        data_loss,
        seg_dir=None,  # one file
        val_seg_dir=None,
        Sf_file=None,  # one file
        val_Sf_file=None,
        auxi_label=None,
        initial_epoch=0):
    """
    Train a CVPR-2018 registration network with a segmentation-aware
    gradient loss and plot the training/validation loss afterwards.

    :param data_dir: training volume names, forwarded unchanged to the data
        generator (must be non-empty)
    :param val_data_dir: validation volume names, forwarded unchanged to the
        data generator (must be non-empty)
    :param atlas_file: training atlas image, opened with PIL
    :param val_atlas_file: validation atlas image, opened with PIL
    :param model: ``'vm1'`` or ``'vm2'`` (based on CVPR 2018 paper); any other
        value selects the doubled-filter ``vm2`` layout ('vm2double')
    :param model_dir: the model directory to save to (created if missing)
    :param gpu_id: gpu id(s) for CUDA_VISIBLE_DEVICES, e.g. ``'0'`` or
        ``'0,1'``; an int is accepted and converted to a string
    :param lr: learning rate
    :param nb_epochs: epoch index at which training stops (Keras ``epochs``)
    :param reg_param: weight of the gradient loss term
    :param gama_param: first argument handed to ``losses.Grad``
    :param steps_per_epoch: batches per epoch, i.e. the frequency with which
        models are saved
    :param batch_size: batch size; must be a multiple of the number of gpus
    :param load_model_file: optional h5 model file to initialize with
        (``None`` or ``''`` to skip)
    :param data_loss: ``'mse'``, ``'cc'`` or ``'ncc'``
    :param seg_dir: image file read into ``Sm_`` for the training loss
        (required although it defaults to None)
    :param val_seg_dir: image file read into ``val_Sm_`` for the validation
        metric (required although it defaults to None)
    :param Sf_file: image file read into ``Sf_`` for the training loss
        (required although it defaults to None)
    :param val_Sf_file: image file read into ``val_Sf_`` for the validation
        metric (required although it defaults to None)
    :param auxi_label: currently unused
    :param initial_epoch: epoch index to resume from (default 0)
    :raises ValueError: if any of ``seg_dir``, ``val_seg_dir``, ``Sf_file``
        or ``val_Sf_file`` is None
    """
    # All four auxiliary images feed losses.Grad below. Fail early with a
    # clear message instead of crashing on an unbound name after model setup.
    required_files = [('seg_dir', seg_dir), ('val_seg_dir', val_seg_dir),
                      ('Sf_file', Sf_file), ('val_Sf_file', val_Sf_file)]
    missing = [name for name, value in required_files if value is None]
    if missing:
        raise ValueError('Missing required image file argument(s): %s' %
                         ', '.join(missing))

    # load atlases; each array gets a leading and a trailing singleton axis
    atlas_vol = _load_image_volume(atlas_file)
    vol_size = atlas_vol.shape[1:-1]
    print(vol_size)

    val_atlas_vol = _load_image_volume(val_atlas_file)
    val_vol_size = val_atlas_vol.shape[1:-1]
    print(val_vol_size)

    Sm_ = _load_image_volume(seg_dir)
    val_Sm_ = _load_image_volume(val_seg_dir)

    # prepare data files (the caller passes the name lists directly)
    train_vol_names = data_dir
    assert len(train_vol_names) > 0, "Could not find any training data"
    val_vol_names = val_data_dir
    assert len(val_vol_names) > 0, "Could not find any validation data"

    # UNET filters for voxelmorph-1 and voxelmorph-2,
    # these are architectures presented in CVPR 2018
    nf_enc = [16, 32, 32, 32]
    if model == 'vm1':
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == 'vm2':
        nf_dec = [32, 32, 32, 32, 32, 16, 16]
    else:  # 'vm2double':
        nf_enc = [f * 2 for f in nf_enc]
        nf_dec = [f * 2 for f in [32, 32, 32, 32, 32, 16, 16]]

    assert data_loss in [
        'mse', 'cc', 'ncc'
    ], 'Loss should be one of mse or cc, found %s' % data_loss
    if data_loss in ['ncc', 'cc']:
        data_loss = losses.NCC().loss

    Sf_ = _load_image_volume(Sf_file)
    val_Sf_ = _load_image_volume(val_Sf_file)

    # prepare model folder
    os.makedirs(model_dir, exist_ok=True)

    # GPU handling. The visible devices are renumbered from 0 once
    # CUDA_VISIBLE_DEVICES is set, hence the fixed device index.
    gpu_id = str(gpu_id)  # os.environ and .split() below need a string
    gpu = '/gpu:0'
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # data generators
    nb_gpus = len(gpu_id.split(','))
    assert np.mod(batch_size, nb_gpus) == 0, \
        'batch_size should be a multiple of the nr. of gpus. ' + \
        'Got batch_size %d, %d gpus' % (batch_size, nb_gpus)

    train_example_gen = datagenerators.example_gen(train_vol_names,
                                                   batch_size=batch_size)
    atlas_vol_bs = np.repeat(atlas_vol, batch_size, axis=0)
    cvpr2018_gen = datagenerators.cvpr2018_gen(train_example_gen,
                                               atlas_vol_bs,
                                               batch_size=batch_size)

    val_example_gen = datagenerators.example_gen(val_vol_names,
                                                 batch_size=batch_size)
    val_atlas_vol_bs = np.repeat(val_atlas_vol, batch_size, axis=0)
    val_cvpr2018_gen = datagenerators.cvpr2018_gen(val_example_gen,
                                                   val_atlas_vol_bs,
                                                   batch_size=batch_size)

    # prepare the model (the session registered with set_session above is the
    # only one needed; no additional tf.Session is created here)
    with tf.device(gpu):
        # in the CVPR layout, the model takes in [image_1, image_2] and
        # outputs [warped_image_1, flow]; image_2 is the atlas here
        model = networks.cvpr2018_net(vol_size, nf_enc, nf_dec)

        # load initial weights
        if load_model_file is not None and load_model_file != '':
            print('loading', load_model_file)
            model.load_weights(load_model_file)

        # save first iteration
        model.save(os.path.join(model_dir, '%02d.h5' % initial_epoch))

    # losses: data term plus the segmentation-aware gradient term
    loss_model = [
        data_loss,
        losses.Grad(gama_param, Sf_, Sm_, penalty='l2').loss
    ]
    loss_weight = [1.0, reg_param]
    # same gradient term on the validation images, reported as a metric
    metrics_2 = losses.Grad(gama_param, val_Sf_, val_Sm_, penalty='l2',
                            flag_vali=True).loss

    # prepare callbacks
    save_file_name = os.path.join(model_dir, '{epoch:02d}.h5')

    # fit generator
    with tf.device(gpu):
        # multi-gpu support
        if nb_gpus > 1:
            save_callback = nrn_gen.ModelCheckpointParallel(save_file_name)
            mg_model = multi_gpu_model(model, gpus=nb_gpus)
        # single-gpu
        else:
            save_callback = ModelCheckpoint(save_file_name)
            mg_model = model

        # compile
        mg_model.compile(optimizer=Adam(lr=lr),
                         loss=loss_model,
                         loss_weights=loss_weight,
                         metrics={'flow': metrics_2})

        # fit
        history = mg_model.fit_generator(cvpr2018_gen,
                                         initial_epoch=initial_epoch,
                                         epochs=nb_epochs,
                                         callbacks=[save_callback],
                                         steps_per_epoch=steps_per_epoch,
                                         validation_data=val_cvpr2018_gen,
                                         validation_steps=1,
                                         verbose=2)

    # plot
    print('model', mg_model.metrics_names)
    print('keys()', history.history.keys())
    plt.plot(history.history['loss'])
    legend_labels = ['Train']
    # Draw the validation curve as well so the legend matches what is shown.
    if 'val_loss' in history.history:
        plt.plot(history.history['val_loss'])
        legend_labels.append('Validation')
    plt.title('cvpr_auxi_loss')
    plt.ylabel('loss')
    plt.xlabel('Epoch')
    plt.legend(legend_labels)
    plt.show()


def _load_image_volume(path):
    """Read an image with PIL, close the file, and return the pixels with a leading and a trailing singleton axis."""
    with Image.open(path) as img:
        return np.array(img)[np.newaxis, ..., np.newaxis]
def train(data_dir, atlas_file, model_dir, gpu_id, lr, nb_epochs,
          prior_lambda, image_sigma, steps_per_epoch, batch_size,
          load_model_file, bidir, initial_epoch=0):
    """
    Train the MICCAI-2018 (probabilistic) network with ``fit_generator``.

    The model is saved to ``<model_dir>/<epoch>.h5`` once before training
    (using ``initial_epoch``) and after every epoch through a checkpoint
    callback.

    :param data_dir: folder that is searched for ``*.npz`` training volumes
    :param atlas_file: atlas filename; an npz file with a 'vol' variable
    :param model_dir: model folder to save to (created if missing)
    :param gpu_id: gpu id(s) for CUDA_VISIBLE_DEVICES, e.g. ``'0'`` or
        ``'0,1'``; an int is accepted and converted to a string
    :param lr: learning rate
    :param nb_epochs: epoch index at which training stops (Keras ``epochs``)
    :param prior_lambda: the prior_lambda, the scalar in front of the
        smoothing laplacian, in MICCAI paper
    :param image_sigma: the image sigma in MICCAI paper
    :param steps_per_epoch: batches per epoch, i.e. the frequency with which
        models are saved
    :param batch_size: batch size; must be a multiple of the number of gpus
    :param load_model_file: optional h5 model file to initialize with
        (``None`` or ``''`` to skip)
    :param bidir: logical whether to use bidirectional cost function
    :param initial_epoch: epoch index to resume from (default 0)
    """
    # load atlas from provided files
    atlas_vol = np.load(atlas_file)['vol'][np.newaxis, ..., np.newaxis]
    vol_size = atlas_vol.shape[1:-1]

    # prepare data files
    train_vol_names = glob.glob(os.path.join(data_dir, '*.npz'))
    random.shuffle(train_vol_names)  # shuffle volume list
    assert len(train_vol_names) > 0, "Could not find any training data"

    # Diffeomorphic network architecture used in MICCAI 2018 paper
    nf_enc = [16, 32, 32, 32]
    nf_dec = [32, 32, 32, 32, 16, 3]

    # prepare model folder (makedirs also handles nested, not-yet-existing paths)
    os.makedirs(model_dir, exist_ok=True)

    # gpu handling. The visible devices are renumbered from 0 once
    # CUDA_VISIBLE_DEVICES is set, hence the fixed device index.
    gpu_id = str(gpu_id)  # os.environ and .split() below need a string
    gpu = '/gpu:0'
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # prepare the model
    with tf.device(gpu):
        # the MICCAI 2018 model takes in [image_1, image_2] and outputs
        # [warped_image_1, velocity_stats]; image_2 is the atlas here
        model = networks.miccai2018_net(vol_size, nf_enc, nf_dec, bidir=bidir)

        # load initial weights
        if load_model_file is not None and load_model_file != "":
            model.load_weights(load_model_file)

        # save first iteration
        model.save(os.path.join(model_dir, '%02d.h5' % initial_epoch))

        # compile
        # note: best to supply vol_shape here than to let tf figure it out.
        flow_vol_shape = model.outputs[-1].shape[1:-1]
        loss_class = losses.Miccai2018(image_sigma, prior_lambda,
                                       flow_vol_shape=flow_vol_shape)
        if bidir:
            # one reconstruction term per direction, weighted equally
            model_losses = [
                loss_class.recon_loss, loss_class.recon_loss,
                loss_class.kl_loss
            ]
            loss_weights = [0.5, 0.5, 1]
        else:
            model_losses = [loss_class.recon_loss, loss_class.kl_loss]
            loss_weights = [1, 1]

    # data generator
    nb_gpus = len(gpu_id.split(','))
    assert np.mod(batch_size, nb_gpus) == 0, \
        'batch_size should be a multiple of the nr. of gpus. ' + \
        'Got batch_size %d, %d gpus' % (batch_size, nb_gpus)

    train_example_gen = datagenerators.example_gen(train_vol_names,
                                                   batch_size=batch_size)
    atlas_vol_bs = np.repeat(atlas_vol, batch_size, axis=0)
    miccai2018_gen = datagenerators.miccai2018_gen(train_example_gen,
                                                   atlas_vol_bs,
                                                   batch_size=batch_size,
                                                   bidir=bidir)

    # prepare callbacks
    save_file_name = os.path.join(model_dir, '{epoch:02d}.h5')

    # fit generator
    with tf.device(gpu):
        # multi-gpu support
        if nb_gpus > 1:
            save_callback = nrn_gen.ModelCheckpointParallel(save_file_name)
            mg_model = multi_gpu_model(model, gpus=nb_gpus)
        # single gpu
        else:
            save_callback = ModelCheckpoint(save_file_name)
            mg_model = model

        mg_model.compile(optimizer=Adam(lr=lr),
                         loss=model_losses,
                         loss_weights=loss_weights)
        mg_model.fit_generator(miccai2018_gen,
                               initial_epoch=initial_epoch,
                               epochs=nb_epochs,
                               callbacks=[save_callback],
                               steps_per_epoch=steps_per_epoch,
                               verbose=1)
def train(model, model_dir, gpu_id, lr, n_iterations, reg_param,
          model_save_iter, batch_size=1):
    """
    Train a voxelmorph U-Net step by step with ``train_on_batch``.

    Relies on the module-level names ``vol_size``, ``train_vol_names`` and
    ``atlas_vol``. Steps ``0 .. n_iterations - 1`` are executed and the model
    is written to ``<model_dir>/<step>.h5`` whenever
    ``step % model_save_iter == 0``.

    :param model: 'vm1' picks the smaller decoder; any other value picks the
        vm2 decoder (based on CVPR 2018 paper)
    :param model_dir: the model directory to save to
    :param gpu_id: integer specifying the gpu to use
    :param lr: learning rate
    :param n_iterations: number of training iterations
    :param reg_param: the smoothness/reconstruction tradeoff parameter
        (lambda in CVPR paper)
    :param model_save_iter: frequency with which to save models
    :param batch_size: leading dimension of the zero flow target, default 1
    """
    # make sure the output folder is there
    if not os.path.isdir(model_dir):
        os.mkdir(model_dir)

    # restrict the process to the requested card and start a Keras session
    # that allocates GPU memory on demand
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    session_config.allow_soft_placement = True
    set_session(tf.Session(config=session_config))

    # encoder/decoder filter counts for voxelmorph-1 and voxelmorph-2
    # (architectures presented in CVPR 2018)
    encoder_filters = [16, 32, 32, 32]
    decoder_filters = ([32, 32, 32, 32, 8, 8] if model == 'vm1'
                       else [32, 32, 32, 32, 32, 16, 16])

    # the network takes [image_1, image_2] and outputs [warped_image_1, flow];
    # image_2 is the atlas
    net = networks.unet(vol_size, encoder_filters, decoder_filters)
    optimizer = Adam(lr=lr)
    loss_functions = [losses.cc3D(), losses.gradientLoss('l2')]
    net.compile(optimizer=optimizer,
                loss=loss_functions,
                loss_weights=[1.0, reg_param])

    # to start from existing weights, load them here, e.g.
    # net.load_weights(os.path.join(model_dir, '120000.h5'))

    # training data and the all-zero target for the flow output
    volume_gen = datagenerators.example_gen(train_vol_names)
    zero_flow = np.zeros((batch_size, *vol_size, 3))

    # train_on_batch plus a custom print function is used instead of
    # fit_generator and Keras callbacks.
    for step in range(n_iterations):
        batch = next(volume_gen)
        moving = batch[0]

        step_loss = net.train_on_batch([moving, atlas_vol],
                                       [atlas_vol, zero_flow])
        step_loss = step_loss if isinstance(step_loss, list) else [step_loss]

        print_loss(step, 1, step_loss)

        if step % model_save_iter == 0:
            checkpoint_path = os.path.join(model_dir, str(step) + '.h5')
            net.save(checkpoint_path)
def _load_prob_atlas(atlas_file, mapping_file):
    """
    load the probabilistic atlas and, if a mapping is given, merge its labels into tissue types

    :param atlas_file: npz file; its 'vol_data' entry is read, with labels on the last axis
    :param mapping_file: npz file whose 'mapping' entry holds one tissue index per
        atlas label, or None to keep the atlas labels as they are
    :return: tuple (atlas_vol, nb_labels); atlas_vol has a leading axis of size 1
        added and nb_labels is the size of its last axis
    """
    with np.load(atlas_file) as atlas_npz:
        atlas_full = atlas_npz['vol_data'][np.newaxis, ...]

    if mapping_file is None:
        return atlas_full, atlas_full.shape[-1]

    with np.load(mapping_file) as mapping_npz:
        mapping = mapping_npz['mapping'].astype('int').flatten()
    assert len(mapping) == atlas_full.shape[-1], \
        'mapping shape %d is inconsistent with atlas shape %d' % (len(mapping), atlas_full.shape[-1])

    # tissue indices are used as positions on the last axis, hence max + 1 channels
    nb_labels = int(1 + np.max(mapping))
    atlas_vol = np.zeros([1, *atlas_full.shape[1:-1], nb_labels])

    # sum the probability maps of all labels that map to the same tissue type
    for label, tissue in enumerate(mapping):
        atlas_vol[0, ..., tissue] += atlas_full[0, ..., label]

    return atlas_vol, nb_labels


def train_unsupervised_segmentation(data_dir,
                                    atlas_file,
                                    mapping_file,
                                    model,
                                    model_dir,
                                    gpu_id,
                                    lr,
                                    nb_epochs,
                                    init_stats,
                                    reg_param,
                                    steps_per_epoch,
                                    batch_size,
                                    stat_post_warp,
                                    warp_method,
                                    load_model_file,
                                    initial_epoch=0):
    """
    model training function

    :param data_dir: folder with npz files (coregistered, intensity normalized)
    :param atlas_file: file with probabilistic atlas (coregistered to images)
    :param mapping_file: file with mapping from labels to tissue types (None for no grouping)
    :param model: registration (voxelmorph) model: vm1, vm2, or vm2double
        (any value other than 'vm1' / 'vm2' is treated as vm2double)
    :param model_dir: the model directory to save to (created if missing)
    :param gpu_id: gpu(s) to use: an integer, or a comma-separated string such as '0,1'
    :param lr: learning rate
    :param nb_epochs: number of epochs
    :param init_stats: npz file with initial guesses for the Gaussian parameters,
        read from its 'init_mu' and 'init_std' entries
    :param reg_param: smoothness/reconstruction tradeoff parameter (lambda in the paper)
    :param steps_per_epoch: number of generator batches that make up one epoch
        (checkpoints are named with the '{epoch:02d}.h5' pattern)
    :param batch_size: batch size; must be a multiple of the number of gpus.
        can be larger than 1, depends on GPU memory and volume size
    :param stat_post_warp: set to 1 to use warped atlas to estimate Gaussian parameters
    :param warp_method: set to 'WARP' if you want to warp the atlas;
        for any other value the regularization weight is forced to 0
    :param load_model_file: optional h5 model file to initialize with (None to skip)
    :param initial_epoch: initial epoch
    """

    # gpu_id ends up in an environment variable and is split on ',', both of which
    # need a string; normalize once so integer ids work as well
    gpu_id = str(gpu_id)
    nb_gpus = len(gpu_id.split(','))
    # checked up front so a bad batch size fails before any model is built or saved
    assert np.mod(batch_size, nb_gpus) == 0, \
        'batch_size should be a multiple of the nr. of gpus. ' + \
        'Got batch_size %d, %d gpus' % (batch_size, nb_gpus)

    # load reference soft edge and corresponding mask from provided files
    # (we used size 160x192x224).
    # Also: group labels in tissue types, if necessary
    atlas_vol, nb_labels = _load_prob_atlas(atlas_file, mapping_file)
    vol_size = atlas_vol.shape[1:-1]

    # load guesses for means and variances (file is opened once and closed again)
    with np.load(init_stats) as stats:
        init_mu = stats['init_mu']
        init_sigma = stats['init_std']

    # prepare data files
    train_vol_names = glob.glob(os.path.join(data_dir, '*.npz'))
    random.shuffle(train_vol_names)
    assert len(train_vol_names) > 0, "Could not find any training data"

    # UNET filters for voxelmorph-1 and voxelmorph-2,
    # these are architectures presented in CVPR 2018
    nf_enc = [16, 32, 32, 32]
    if model == 'vm1':
        nf_dec = [32, 32, 32, 32, 8, 8]
    elif model == 'vm2':
        nf_dec = [32, 32, 32, 32, 32, 16, 16]
    else:  # 'vm2double':
        nf_enc = [f * 2 for f in nf_enc]
        nf_dec = [f * 2 for f in [32, 32, 32, 32, 32, 16, 16]]

    # prepare model and log folders
    log_dir = os.path.join(model_dir, 'logs')
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # GPU handling
    # only the requested gpus are made visible, so the device index is always 0
    gpu = '/gpu:0'
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    set_session(tf.Session(config=config))

    # prepare the model
    # (kept under its own name so the `model` string argument is not shadowed)
    with tf.device(gpu):
        net = networks.cvpr2018_net_probatlas(vol_size, nf_enc, nf_dec, nb_labels,
                                              diffeomorphic=True,
                                              full_size=False,
                                              stat_post_warp=stat_post_warp,
                                              warp_method=warp_method,
                                              init_mu=init_mu,
                                              init_sigma=init_sigma)

        # load initial weights
        if load_model_file is not None:
            print('loading', load_model_file)
            net.load_weights(load_model_file)

        # save first iteration
        net.save(os.path.join(model_dir, '%02d.h5' % initial_epoch))

    # data generator
    train_example_gen = datagenerators.example_gen(train_vol_names, batch_size=batch_size)
    atlas_vol_bs = np.repeat(atlas_vol, batch_size, axis=0)
    cvpr2018_gen = datagenerators.cvpr2018_gen(train_example_gen,
                                               atlas_vol_bs,
                                               batch_size=batch_size)

    # prepare callbacks
    save_file_name = os.path.join(model_dir, '{epoch:02d}.h5')

    # fit generator
    with tf.device(gpu):

        # multi-gpu support
        if nb_gpus > 1:
            save_callback = nrn_gen.ModelCheckpointParallel(save_file_name)
            mg_model = multi_gpu_model(net, gpus=nb_gpus)

        # single-gpu
        else:
            save_callback = ModelCheckpoint(save_file_name)
            mg_model = net

        # tensorBoard callback
        tensorboard = TensorBoard(log_dir=log_dir,
                                  histogram_freq=0,
                                  write_graph=True,
                                  write_images=False)

        # compile loss and parameters
        def data_loss(_, yp):
            # negative mean of the prediction over the voxels where the first
            # network input is positive; the target tensor is ignored
            mask = tf.cast(net.inputs[0] > 0, tf.float32)
            return -K.sum(yp * mask) / K.sum(mask)

        # without warping there is no flow to regularize
        if warp_method != 'WARP':
            reg_param = 0

        # compile
        mg_model.compile(optimizer=Adam(lr=lr),
                         loss=[data_loss, losses.Grad('l2').loss],
                         loss_weights=[1.0, reg_param])

        # fit
        mg_model.fit_generator(cvpr2018_gen,
                               initial_epoch=initial_epoch,
                               epochs=nb_epochs,
                               callbacks=[save_callback, tensorboard],
                               steps_per_epoch=steps_per_epoch,
                               verbose=1)