import os
import random
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping

# The remaining helpers used below (get_argparser, await_and_set_free_gpu,
# set_gpu, Logger, YAMLHParams, get_best_model, create_folders, log,
# ImagePairLoader, FusionModel, init_model, get_voxel_grid_real_space,
# predict_and_map, stack_collections, shuffle, ValDiceScores,
# PrintLayerWeights, highlighted, sparse_fg_precision, sparse_fg_recall,
# contains_all_images) are assumed to come from the MultiPlanarUNet package;
# their exact import paths are not shown in this excerpt.


def entry_func(args=None):
    # Project base path
    args = vars(get_argparser().parse_args(args))
    basedir = os.path.abspath(args["project_dir"])
    overwrite = args["overwrite"]
    continue_training = args["continue_training"]
    eval_prob = args["eval_prob"]
    await_PID = args["wait_for"]
    dice_weight = args["dice_weight"]
    print("Fitting fusion model for project-folder: %s" % basedir)

    # Minimum images in validation set before also using training images
    min_val_images = 15

    # Fusion model training params
    epochs = args["epochs"]
    fm_batch_size = args["batch_size"]

    # Early stopping params
    early_stopping = args["early_stopping"]

    # Wait for PID?
    if await_PID:
        from MultiPlanarUNet.utils import await_PIDs
        await_PIDs(await_PID)

    # Fetch GPU(s)
    num_GPUs = args["num_GPUs"]
    force_gpu = args["force_GPU"]

    # Wait for free GPU
    if not force_gpu:
        await_and_set_free_gpu(N=num_GPUs, sleep_seconds=120)
        num_GPUs = 1
    else:
        set_gpu(force_gpu)
        num_GPUs = len(force_gpu.split(","))

    # Get logger
    logger = Logger(base_path=basedir, active_file="train_fusion",
                    overwrite_existing=overwrite)

    # Get YAML hyperparameters
    hparams = YAMLHParams(os.path.join(basedir, "train_hparams.yaml"))

    # Get some key settings
    n_classes = hparams["build"]["n_classes"]
    if hparams["build"]["out_activation"] == "linear":
        # Trained with logit targets?
        hparams["build"]["out_activation"] = \
            "softmax" if n_classes > 1 else "sigmoid"

    # Get views
    views = np.load("%s/views.npz" % basedir)["arr_0"]
    del hparams["fit"]["views"]

    # Get weights and set fusion (output) path
    weights = get_best_model("%s/model" % basedir)
    weights_name = os.path.splitext(os.path.split(weights)[-1])[0]
    fusion_weights = "%s/model/fusion_weights/" \
                     "%s_fusion_weights.h5" % (basedir, weights_name)
    create_folders(os.path.split(fusion_weights)[0])

    # Log a few things
    log(logger, hparams, views, weights, fusion_weights)

    # Check if fusion weights already exist...
    if not overwrite and os.path.exists(fusion_weights):
        from sys import exit
        print("\n[*] A fusion weights file already exists at '%s'."
              "\n    Use the --overwrite flag to overwrite." % fusion_weights)
        exit(0)

    # Load validation data
    images = ImagePairLoader(**hparams["val_data"], logger=logger)
    is_validation = {m.id: True for m in images}

    # Define random sets of images to train on simultaneously
    # (cannot use all images at once due to memory constraints)
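    # Each round trains on a random subset of `images_per_round` IDs, since
    # the per-voxel, per-view predictions for all images cannot be held in
    # memory at once. The ID list is first padded with randomly re-sampled
    # IDs so it splits evenly: e.g. with sub_size = 5 and 13 IDs,
    # rest = 5 * ceil(13 / 5) - 13 = 2 duplicate IDs are appended.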
    image_IDs = [m.id for m in images]

    if len(images) < min_val_images:
        # Pick N random training images
        diff = min_val_images - len(images)
        logger("Adding %i training images to set" % diff)

        # Load the training data and pick diff images
        train = ImagePairLoader(**hparams["train_data"], logger=logger)
        indx = np.random.choice(np.arange(len(train)), diff,
                                replace=diff > len(train))

        # Add the images to the image set
        train_add = [train[i] for i in indx]
        for m in train_add:
            is_validation[m.id] = False
            image_IDs.append(m.id)
        images.add_images(train_add)

    # Append random IDs so that len(image_IDs) % sub_size == 0
    sub_size = args["images_per_round"]
    rest = int(sub_size * np.ceil(len(image_IDs) / sub_size)) - len(image_IDs)
    if rest:
        image_IDs += list(np.random.choice(image_IDs, rest, replace=False))

    # Shuffle and split
    random.shuffle(image_IDs)
    sets = [set(s) for s in np.array_split(image_IDs,
                                           len(image_IDs) // sub_size)]
    assert contains_all_images(sets, image_IDs)

    # Define fusion model (named 'org' to store a reference to the original
    # model if a multi-GPU model is created below)
    fusion_model_org = FusionModel(n_inputs=len(views), n_classes=n_classes,
                                   weight=dice_weight, logger=logger,
                                   verbose=False)
    if continue_training:
        fusion_model_org.load_weights(fusion_weights)
        print("\n[OBS] CONTINUED TRAINING FROM:\n", fusion_weights)

    # Define model
    unet = init_model(hparams["build"], logger)
    print("\n[*] Loading weights: %s\n" % weights)
    unet.load_weights(weights, by_name=True)

    if num_GPUs > 1:
        from tensorflow.keras.utils import multi_gpu_model

        # Set for predictor model
        unet = multi_gpu_model(unet, gpus=num_GPUs)
        unet.n_classes = n_classes

        # Set for fusion model
        fusion_model = multi_gpu_model(fusion_model_org, gpus=num_GPUs)
    else:
        fusion_model = fusion_model_org

    # Compile the model
    logger("Compiling...")
    metrics = ["sparse_categorical_accuracy",
               sparse_fg_precision, sparse_fg_recall]
    fusion_model.compile(optimizer=Adam(lr=1e-3),
                         loss=fusion_model_org.loss,
                         metrics=metrics)
    fusion_model_org._log()

    try:
        _run_fusion_training(sets, logger, hparams, min_val_images,
                             is_validation, views, n_classes, unet,
                             fusion_model_org, fusion_model, early_stopping,
                             fm_batch_size, epochs, eval_prob, fusion_weights)
    except KeyboardInterrupt:
        pass
    finally:
        if not os.path.exists(os.path.split(fusion_weights)[0]):
            os.mkdir(os.path.split(fusion_weights)[0])

        # Save fusion model weights
        # OBS: Must be the original model if multi-GPU training was used!
        fusion_model_org.save_weights(fusion_weights)
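# `contains_all_images` is a package helper whose import is not shown in this
# excerpt. The assert above is assumed to verify that every image ID ends up
# in at least one sub-set; a minimal, hypothetical sketch of such a check:
#
#     def contains_all_images(sets, image_IDs):
#         return all(any(image_id in s for s in sets)
#                    for image_id in image_IDs)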
def _run_fusion_training(sets, logger, hparams, min_val_images, is_validation,
                         views, n_classes, unet, fusion_model_org,
                         fusion_model, early_stopping, fm_batch_size, epochs,
                         eval_prob, fusion_weights_path):
    for _round, _set in enumerate(sets):
        s = "Set %i/%i:\n%s" % (_round + 1, len(sets), _set)
        logger("\n%s" % highlighted(s))

        # Reload data
        images = ImagePairLoader(**hparams["val_data"])
        if len(images) < min_val_images:
            images.add_images(ImagePairLoader(**hparams["train_data"]))

        # Get list of ImagePair objects to run on
        image_set_dict = {m.id: m for m in images if m.id in _set}

        # Fetch points from the set images
        points_collection = []
        targets_collection = []
        N_im = len(image_set_dict)
        for num_im, image_id in enumerate(list(image_set_dict.keys())):
            logger("")
            logger(highlighted("(%i/%i) Running on %s (%s)"
                               % (num_im + 1, N_im, image_id,
                                  "val" if is_validation[image_id]
                                  else "train")))

            # Set the current ImagePair
            image = image_set_dict[image_id]
            images.images = [image]

            # Load views
            kwargs = hparams["fit"]
            kwargs.update(hparams["build"])
            seq = images.get_sequencer(views=views, **kwargs)

            # Get voxel grid in real space
            voxel_grid_real_space = get_voxel_grid_real_space(image)

            # Get array to store predictions across all views
            targets = image.labels.reshape(-1, 1)
            points = np.empty(shape=(len(targets), len(views), n_classes),
                              dtype=np.float32)
            points.fill(np.nan)

            # Predict on all views
            for k, v in enumerate(views):
                print("\nView: %s" % v)
                points[:, k, :] = predict_and_map(
                    model=unet,
                    seq=seq,
                    image=image,
                    view=v,
                    voxel_grid_real_space=voxel_grid_real_space,
                    n_planes='same+20',
                    targets=targets,
                    eval_prob=eval_prob).reshape(-1, n_classes)

            # Clean up a bit
            del image_set_dict[image_id]
            del image  # Should be garbage collected at this point anyway

            # Add to collections
            points_collection.append(points)
            targets_collection.append(targets)

        # Stack points into one matrix
        logger("Stacking points...")
        X, y = stack_collections(points_collection, targets_collection)

        # Shuffle train
        print("Shuffling points...")
        X, y = shuffle(X, y)

        print("Getting validation set...")
        val_ind = int(0.20 * X.shape[0])
        X_val, y_val = X[:val_ind], y[:val_ind]
        X, y = X[val_ind:], y[val_ind:]

        # Prepare dice score callback for validation data
        val_cb = ValDiceScores((X_val, y_val), n_classes, 50000, logger)

        # Callbacks
        cbs = [val_cb,
               CSVLogger(filename="logs/fusion_training.csv",
                         separator=",", append=True),
               PrintLayerWeights(fusion_model_org.layers[-1], every=1,
                                 first=1000, per_epoch=True, logger=logger)]

        es = EarlyStopping(monitor='val_dice', min_delta=0.0,
                           patience=early_stopping, verbose=1, mode='max')
        cbs.append(es)

        # Start training
        try:
            fusion_model.fit(X, y, batch_size=fm_batch_size,
                             epochs=epochs, callbacks=cbs, verbose=1)
        except KeyboardInterrupt:
            pass
        fusion_model_org.save_weights(fusion_weights_path)
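# Standard script entry point. The example flags below are inferred from the
# `args` keys read in `entry_func`; the authoritative flag definitions live
# in `get_argparser`, which is not part of this excerpt:
#
#   python train_fusion.py --project_dir ./my_project --num_GPUs 1 --overwrite
#
if __name__ == "__main__":
    entry_func()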