def load_data(args, batch_size=8, batch_idx=0, mode="train"):
    """Load one batch of images and masks for OpenVINO inference.

    Modify this to load your data and labels.

    Builds a non-augmented, non-shuffled ``DataGenerator`` over
    ``args.data_path`` with cubic patches of side ``args.patch_dim``, fetches
    a single batch, and converts it from channels-last to channels-first.

    Args:
        args: Parsed command-line arguments. Reads ``patch_dim``,
            ``number_input_channels``, ``train_test_split``, ``random_seed``
            and ``data_path``.
        batch_size: Samples per batch (default 8, the previously hard-coded
            value).
        batch_idx: Index of the batch to fetch (default 0, the previously
            hard-coded value).
        mode: Dataset partition name passed to ``DataGenerator`` (default
            ``"train"``, the previously hard-coded value).

    Returns:
        tuple: ``(imgs, msks, file_ids)``. ``imgs`` and ``msks`` have had
        their last axis moved to axis 1 via ``transpose((0, 4, 1, 2, 3))``.
        ``file_ids`` is what ``get_batch_fileIDs(batch_idx)`` returns.
    """
    data_params = {
        "dim": (args.patch_dim, args.patch_dim, args.patch_dim),
        "batch_size": batch_size,
        "n_in_channels": args.number_input_channels,
        "n_out_channels": 1,
        "train_test_split": args.train_test_split,
        "augment": False,
        "shuffle": False,
        "seed": args.random_seed,
    }

    # NOTE(review): ``mode`` defaults to "train" so existing callers see no
    # change. If this batch is meant for held-out evaluation, "test" may be
    # the intended partition -- confirm.
    data_generator = DataGenerator(mode, args.data_path, **data_params)
    data_generator.print_info()

    imgs, msks = data_generator.get_batch(batch_idx)
    file_ids = data_generator.get_batch_fileIDs(batch_idx)

    # OpenVINO uses channels-first tensors (NCHWD), while TensorFlow usually
    # uses channels-last (NHWDC). Move the trailing channel axis to axis 1.
    imgs = imgs.transpose((0, 4, 1, 2, 3))
    msks = msks.transpose((0, 4, 1, 2, 3))

    return imgs, msks, file_ids
# Settings common to the training and validation generators. Only batch
# size, augmentation, shuffling and seed differ between the two.
common_generator_params = {
    "dim": (args.patch_height, args.patch_width, args.patch_depth),
    "n_in_channels": args.number_input_channels,
    "n_out_channels": 1,
    "train_test_split": args.train_test_split,
    "validate_test_split": args.validate_test_split,
}

# Training batches are augmented and shuffled. Each Horovod worker seeds its
# generator with its own rank (args.random_seed is not used here).
training_data_params = dict(
    common_generator_params,
    batch_size=args.bz,
    augment=True,
    shuffle=True,
    seed=hvd.rank(),
)
training_generator = DataGenerator("train", args.data_path, **training_data_params)

# Only the rank-0 worker prints the dataset summary.
if hvd.rank() == 0:
    training_generator.print_info()

# Validation uses one sample per batch, no augmentation or shuffling, and the
# user-supplied random seed.
validation_data_params = dict(
    common_generator_params,
    batch_size=1,
    augment=False,
    shuffle=False,
    seed=args.random_seed,
)
validation_generator = DataGenerator("validate", args.data_path, **validation_data_params)
# NOTE(review): this line is a whitespace-collapsed run of statements. As
# written, everything after the "#" that follows unet(channels_last = True)
# is a comment, so only the unet_model assignment executes. The statements
# need to be restored onto separate lines.
# Intended flow, from the visible tokens:
#   1. Build a unet helper and load the saved Keras model (args.saved_model)
#      with that helper's custom_objects.
#   2. Build a "test" DataGenerator with batch size 1, no augmentation and
#      no shuffling.
#   3. Evaluate the model on it and print each metric to 4 decimal places.
#   4. Start checking for the "predictions_directory" output directory. The
#      try statement continues past this line.
# NOTE(review): Model.evaluate_generator is deprecated in TF2 Keras
# (Model.evaluate accepts generators). Confirm the TF version before
# switching.
unet_model = unet(channels_last = True) # channels first or last model = K.models.load_model(args.saved_model, custom_objects=unet_model.custom_objects) print("Loading images and masks from test set") validation_data_params = {"dim": (args.patch_height, args.patch_width, args.patch_depth), "batch_size": 1, "n_in_channels": args.number_input_channels, "n_out_channels": 1, "train_test_split": args.train_test_split, "augment": False, "shuffle": False, "seed": args.random_seed} testing_generator = DataGenerator("test", args.data_path, **validation_data_params) testing_generator.print_info() m = model.evaluate_generator(testing_generator, verbose=1, max_queue_size=args.num_prefetched_batches, workers=args.num_data_loaders, use_multiprocessing=False) print("\n\nTest metrics") print("============") for idx, name in enumerate(model.metrics_names): print("{} = {:.4f}".format(name, m[idx])) save_directory = "predictions_directory" try: os.stat(save_directory)