def main():
    """Train the rat-area mitochondria segmentation model from the command line.

    Positional arguments:
        images: folder holding the training image stacks.
        labels: folder holding the matching label stacks.

    Echoes the chosen folders, then builds a ``HUNet`` with an Adam optimizer
    and hands everything to ``train_model`` using the competition rat
    train/validation stacks and the ``'_area'`` label postfix.
    """
    cli = argparse.ArgumentParser(
        description='Train a mitochondria segmentation model')
    # Both positionals share the same help wording apart from the stack kind.
    for arg_name, stack_kind in (('images', 'image'), ('labels', 'label')):
        cli.add_argument(
            arg_name,
            type=str,
            help=f'Folder location of the {stack_kind} stacks for training')
    parsed = cli.parse_args()
    image_dir, label_dir = parsed.images, parsed.labels

    banner = (
        "Starting mitochondria (competition) training script...",
        f"Training stack folder: {image_dir}",
        f"Label stack folder: {label_dir}",
        "",
    )
    for text in banner:
        print(text)

    network = HUNet(PATCH_SHAPE[0], 32)
    adam = torch.optim.Adam(network.parameters(), lr=MITO_LEARNING_RATE)

    run_settings = {
        'model_name': 'mito_8nm_rat_area',
        'loss_func': dice_loss,
        'image_in_folder': image_dir,
        'label_in_folder': label_dir,
        'train_stacks': MITO_COMPETITION_RAT_TRAIN_STACKS,
        'validation_stacks': MITO_COMPETITION_RAT_VALIDATION_STACKS,
        'patch_shape': PATCH_SHAPE,
        'batch_size': 12,
        'epochs': 1000,
        'label_postfix': '_area',
    }
    train_model(model=network, optimizer=adam, **run_settings)
def _predict_and_save(task, stack_name, image_stack, label_stacks_folder,
                      save_folder):
    """Segment one image stack with the latest ``mito_8nm_rat_<task>`` model.

    Builds a fresh ``HUNet(PATCH_SHAPE[0], 32)`` and restores the latest
    saved epoch of ``mito_8nm_rat_<task>`` into it. It then calls
    ``segment_save_and_print_metrics`` with save postfix ``_<task>``.

    Args:
        task: Prediction target, e.g. ``"boundary"`` or ``"area"``. It selects
            the model name, the label stack name and the save postfix.
        stack_name: Base name of the image stack being predicted.
        image_stack: The already-loaded image stack for ``stack_name``.
        label_stacks_folder: Folder searched for ``<stack_name>_<task>``.
        save_folder: Folder forwarded to ``segment_save_and_print_metrics``.
    """
    print(f"running {task} predictions")
    model = HUNet(PATCH_SHAPE[0], 32)
    restore_latest_model_epoch(model, f"mito_8nm_rat_{task}")
    postfix = f"_{task}"
    try:
        label_stack = read_tiff(label_stacks_folder, stack_name + postfix)
    except FileNotFoundError:
        # Labels are optional; None is forwarded as the benchmark.
        # NOTE(review): presumably metrics are skipped for None - confirm in
        # segment_save_and_print_metrics.
        label_stack = None
    segment_save_and_print_metrics(model=model,
                                   stack_name=stack_name,
                                   save_folder=save_folder,
                                   image_stack=image_stack,
                                   benchmark_segmentations=label_stack,
                                   save_postfix=postfix)


def main():
    """Run boundary and area predictions on the mitochondria test stack(s).

    Positional arguments:
        images: folder holding the image stacks to predict.
        labels: folder holding optional label stacks used as benchmarks.
        save_folder: folder the predictions are written to.

    For each test stack the image is loaded once. It is then segmented by the
    boundary model and afterwards by the area model.
    """
    parser = argparse.ArgumentParser(
        description='Predict mitochondria boundary and area segmentations')
    parser.add_argument(
        'images', type=str,
        help='Folder location of the image stacks to predict')
    parser.add_argument(
        'labels', type=str,
        help='Folder location of the (optional) label stacks for metrics')
    parser.add_argument(
        'save_folder', type=str,
        help='Folder to save the predicted segmentations in')
    args = parser.parse_args()
    image_stacks_folder = args.images
    label_stacks_folder = args.labels
    save_folder = args.save_folder

    for mito_test_stack in ["mito_r_test"]:
        print(f"loading {mito_test_stack} for prediction")
        image_stack = read_tiff(image_stacks_folder, mito_test_stack)
        # Same pipeline for both targets; only the task suffix differs.
        for task in ("boundary", "area"):
            _predict_and_save(task, mito_test_stack, image_stack,
                              label_stacks_folder, save_folder)
def test_hunet_forward(self):
    """Check that HUNet's output has the same shape as its input.

    The input is a random tensor of shape
    ``(1, in_channels, in_xy_dim, in_xy_dim)``. The test also prints the
    elapsed time of one forward pass.
    """
    in_channels = 12
    in_xy_dim = 256
    in_shape = (1, in_channels, in_xy_dim, in_xy_dim)
    xx = torch.randn(in_shape)
    model = HUNet(in_channels)
    # Only the output shape is checked, so skip building the autograd graph.
    # This saves memory and makes the timing reflect inference alone.
    with torch.no_grad():
        # perf_counter is monotonic and high-resolution, unlike time.time().
        tic = time.perf_counter()
        out = model(xx)
        toc = time.perf_counter()
    print(f'\n*** HUNet forward time: {toc-tic} ***')
    assert out.shape == in_shape