def main():
    """Train the mitochondria area-segmentation HUNet from the command line.

    Positional CLI arguments:
        images: folder holding the training image stacks.
        labels: folder holding the matching label stacks; ``'_area'`` is
            passed to ``train_model`` as ``label_postfix``.

    The run is configured with:
        - ``dice_loss``;
        - an Adam optimizer at ``MITO_LEARNING_RATE``;
        - the competition rat train/validation stack lists;
        - ``PATCH_SHAPE`` patches, batch size 12 and 1000 epochs;
        - the model name ``'mito_8nm_rat_area'``.
    """
    cli = argparse.ArgumentParser(
        description='Train a mitochondria segmentation model')
    positional_args = (
        ('images', 'Folder location of the image stacks for training'),
        ('labels', 'Folder location of the label stacks for training'),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, type=str, help=arg_help)

    options = cli.parse_args()
    image_dir, label_dir = options.images, options.labels

    banner = (
        "Starting mitochondria (competition) training script...",
        f"Training stack folder: {image_dir}",
        f"Label stack folder: {label_dir}",
        "",
    )
    for banner_line in banner:
        print(banner_line)

    # NOTE(review): the meaning of HUNet's second argument (32) is not
    # visible in this file; the first looks like the input channel count.
    net = HUNet(PATCH_SHAPE[0], 32)
    adam = torch.optim.Adam(net.parameters(), lr=MITO_LEARNING_RATE)
    train_model(
        model=net,
        model_name='mito_8nm_rat_area',
        loss_func=dice_loss,
        optimizer=adam,
        image_in_folder=image_dir,
        label_in_folder=label_dir,
        train_stacks=MITO_COMPETITION_RAT_TRAIN_STACKS,
        validation_stacks=MITO_COMPETITION_RAT_VALIDATION_STACKS,
        patch_shape=PATCH_SHAPE,
        batch_size=12,
        epochs=1000,
        label_postfix='_area',
    )
def _segment_with_saved_model(model_name, postfix, stack_name, image_stack,
                              label_stacks_folder, save_folder):
    """Restore the latest saved HUNet and segment one image stack with it.

    Args:
        model_name: name passed to ``restore_latest_model_epoch`` to pick
            which saved model to load.
        postfix: suffix appended to ``stack_name`` to find the optional
            benchmark label stack. It is also forwarded as ``save_postfix``.
        stack_name: base name of the stack being segmented.
        image_stack: image data previously returned by ``read_tiff``.
        label_stacks_folder: folder searched for the benchmark label stack.
        save_folder: output folder forwarded to
            ``segment_save_and_print_metrics``.
    """
    model = HUNet(PATCH_SHAPE[0], 32)
    restore_latest_model_epoch(model, model_name)

    # Benchmark labels are optional. If the file is missing,
    # segment_save_and_print_metrics is called with
    # benchmark_segmentations=None instead of aborting the run.
    try:
        label_stack = read_tiff(label_stacks_folder, stack_name + postfix)
    except FileNotFoundError:
        label_stack = None

    segment_save_and_print_metrics(model=model,
                                   stack_name=stack_name,
                                   save_folder=save_folder,
                                   image_stack=image_stack,
                                   benchmark_segmentations=label_stack,
                                   save_postfix=postfix)


def main():
    """Segment the mitochondria test stack(s) with the boundary and area models.

    Positional CLI arguments:
        images: folder holding the image stacks to segment.
        labels: folder holding optional benchmark label stacks
            (``<stack>_boundary`` / ``<stack>_area``).
        save_folder: output folder for the segmentation results.

    For each test stack, the latest ``mito_8nm_rat_boundary`` model is
    restored and run first, then the ``mito_8nm_rat_area`` model.
    """
    parser = argparse.ArgumentParser(
        description='Run mitochondria boundary and area segmentation')
    parser.add_argument('images', type=str,
                        help='Folder location of the image stacks to segment')
    parser.add_argument('labels', type=str,
                        help='Folder location of the (optional) label stacks')
    parser.add_argument('save_folder', type=str,
                        help='Output folder for the segmentation results')

    args = parser.parse_args()

    for mito_test_stack in ("mito_r_test",):
        print(f"loading {mito_test_stack} for prediction")
        image_stack = read_tiff(args.images, mito_test_stack)

        # The model name and the label/save postfix share the same
        # target word, so one loop drives both passes.
        for target in ("boundary", "area"):
            print(f"running {target} predictions")
            _segment_with_saved_model(
                model_name=f"mito_8nm_rat_{target}",
                postfix=f"_{target}",
                stack_name=mito_test_stack,
                image_stack=image_stack,
                label_stacks_folder=args.labels,
                save_folder=args.save_folder,
            )
    def test_hunet_forward(self):
        """Check that HUNet preserves the shape of a (1, 12, 256, 256) input.

        Also prints how long the single forward pass took.
        """
        in_channels = 12
        in_xy_dim = 256
        in_shape = (1, in_channels, in_xy_dim, in_xy_dim)
        xx = torch.randn(in_shape)

        model = HUNet(in_channels)
        # perf_counter is monotonic and high-resolution, which makes it the
        # right clock for timing a short interval. time.time() follows the
        # system wall clock, which can be adjusted mid-measurement.
        tic = time.perf_counter()
        out = model(xx)
        toc = time.perf_counter()
        print(f'\n*** HUNet forward time: {toc-tic} ***')
        assert out.shape == in_shape