Esempio n. 1
0
# Model names that resolve to cellpose's own weights; any other value of
# --pretrained_model is treated as a filesystem path.
_BUILTIN_MODELS = ('cyto', 'nuclei', 'cyto2')


def _build_parser():
    """Create the argument parser holding every cellpose command-line option."""
    parser = argparse.ArgumentParser(description='cellpose parameters')

    def flag(name, help_text):
        # boolean switch: False unless present on the command line
        parser.add_argument(name, action='store_true', help=help_text)

    def option(name, default, kind, help_text):
        # optional valued argument converted with `kind`
        parser.add_argument(name, default=default, type=kind, help=help_text)

    flag('--check_mkl', 'check if mkl working')
    flag('--mkldnn', 'for mxnet, force MXNET_SUBGRAPH_BACKEND = "MKLDNN"')
    flag('--train', 'train network using images in dir')
    # the [] default (rather than '') is kept: main() only ever takes len() of it
    option('--dir', [], str, 'folder containing data to run or train on')
    flag('--look_one_level_down', '')
    flag('--mxnet', 'use mxnet')
    option('--img_filter', [], str, 'end string for images to run on')
    flag('--use_gpu', 'use gpu if mxnet with cuda installed')
    flag('--fast_mode',
         "make code run faster by turning off 4 network averaging")
    flag('--resample',
         "run dynamics on full image (slower for images with large diameters)")
    flag('--no_interp',
         'do not interpolate when running dynamics (was default)')

    # settings for running cellpose
    flag('--do_3D',
         'process images as 3D stacks of images (nplanes x nchan x Ly x Lx')
    option('--pretrained_model', 'cyto', str, 'model to use')
    option('--chan', 0, int,
           'channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE')
    option(
        '--chan2', 0, int,
        'nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE'
    )
    flag('--invert', 'invert grayscale channel')
    flag(
        '--all_channels',
        'use all channels in image if using own model and images with special channels'
    )
    option('--diameter', 30., float,
           'cell diameter, if 0 cellpose will estimate for each image')
    # NOTE(review): --stitch_threshold is parsed but never forwarded to
    # model.eval in _run_segmentation -- confirm whether that is intended.
    option(
        '--stitch_threshold', 0.0, float,
        'compute masks in 2D then stitch together masks with IoU>0.9 across planes'
    )
    option('--flow_threshold', 0.4, float,
           'flow error threshold, 0 turns off this optional QC step')
    option('--cellprob_threshold', 0.0, float,
           'cell probability threshold, centered at 0.0')
    flag('--save_png', 'save masks as png')
    flag('--save_outlines', 'save outlines as text file for ImageJ')
    flag('--save_tif', 'save masks as tif')
    flag('--no_npy', 'suppress saving of npy')
    option('--channel_axis', None, int,
           'axis of image which corresponds to image channels')
    option('--z_axis', None, int,
           'axis of image which corresponds to Z dimension')
    flag('--exclude_on_edges', 'discard masks which touch edges of image')
    option('--unet', 0, int,
           'run standard unet instead of cellpose flow output')
    option('--nclasses', 3, int,
           'if running unet, choose 2 or 3, otherwise not used')

    # settings for training
    flag('--train_size', 'train size network at end of training')
    option('--mask_filter', '_masks', str, 'end string for masks to run on')
    option('--test_dir', [], str, 'folder containing test data (optional)')
    option('--learning_rate', 0.2, float, 'learning rate')
    option('--n_epochs', 500, int, 'number of epochs')
    option('--batch_size', 8, int, 'batch size')
    option('--residual_on', 1, int, 'use residual connections')
    option('--style_on', 1, int, 'use style vector')
    option(
        '--concatenation', 0, int,
        'concatenate downsampled layers with upsampled layers (off by default which means they are added)'
    )
    return parser


def _run_segmentation(args, channels, imf, device, gpu):
    """Segment every image found in ``args.dir`` and write the requested outputs.

    ``args.pretrained_model`` may be rewritten to ``'cyto'`` when the requested
    model is unavailable (missing path, or ``cyto2`` under mxnet).
    """
    tic = time.time()

    # A non-built-in model name is a path; fall back to 'cyto' if it is missing.
    cpmodel_path = None
    if args.pretrained_model not in _BUILTIN_MODELS:
        cpmodel_path = args.pretrained_model
        if not os.path.exists(cpmodel_path):
            logger.warning('model path does not exist, using cyto model')
            args.pretrained_model = 'cyto'
    # Evaluated after the fallback above, so a missing path counts as built-in.
    builtin = args.pretrained_model in _BUILTIN_MODELS

    image_names = io.get_image_files(
        args.dir,
        args.mask_filter,
        imf=imf,
        look_one_level_down=args.look_one_level_down)
    nimg = len(image_names)

    cstr0 = ['GRAY', 'RED', 'GREEN', 'BLUE']
    cstr1 = ['NONE', 'RED', 'GREEN', 'BLUE']
    logger.info(
        '>>>> running cellpose on %d images using chan_to_seg %s and chan (opt) %s'
        % (nimg, cstr0[channels[0]], cstr1[channels[1]]))

    if builtin:
        if args.mxnet and args.pretrained_model == 'cyto2':
            logger.warning(
                'cyto2 model not available in mxnet, using cyto model')
            args.pretrained_model = 'cyto'
        model = models.Cellpose(gpu=gpu,
                                device=device,
                                model_type=args.pretrained_model,
                                torch=(not args.mxnet))
    else:
        if args.all_channels:
            channels = None
        model = models.CellposeModel(gpu=gpu,
                                     device=device,
                                     pretrained_model=cpmodel_path,
                                     torch=(not args.mxnet))

    if args.diameter != 0:
        diameter = args.diameter
        logger.info('>>>> using diameter %0.2f for all images' % diameter)
    elif builtin:
        # None is passed through to model.eval as the diameter
        diameter = None
        logger.info('>>>> estimating diameter for each image')
    else:
        logger.info(
            '>>>> using user-specified model, no auto-diameter estimation available'
        )
        diameter = model.diam_mean

    tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)
    for image_name in tqdm(image_names, file=tqdm_out):
        image = io.imread(image_name)
        out = model.eval(image,
                         channels=channels,
                         diameter=diameter,
                         do_3D=args.do_3D,
                         net_avg=(not args.fast_mode),
                         augment=False,
                         resample=args.resample,
                         flow_threshold=args.flow_threshold,
                         cellprob_threshold=args.cellprob_threshold,
                         invert=args.invert,
                         batch_size=args.batch_size,
                         interp=(not args.no_interp),
                         channel_axis=args.channel_axis,
                         z_axis=args.z_axis)
        masks, flows = out[:2]
        # when eval returns more than three items, the last one is used as the
        # diameters to save; otherwise the diameter passed in is used
        diams = out[-1] if len(out) > 3 else diameter
        if args.exclude_on_edges:
            masks = utils.remove_edge_masks(masks)
        if not args.no_npy:
            io.masks_flows_to_seg(image, masks, flows, diams, image_name,
                                  channels)
        if args.save_png or args.save_tif or args.save_outlines:
            io.save_masks(image,
                          masks,
                          flows,
                          image_name,
                          png=args.save_png,
                          tif=args.save_tif,
                          outlines=args.save_outlines)
    logger.info('>>>> completed in %0.3f sec' % (time.time() - tic))


def _run_training(args, channels, imf, device):
    """Train the segmentation model and/or the size model on ``args.dir``.

    Raises:
        ValueError: if only the size model is requested but the model path
            does not exist, or if ``--all_channels`` is set and the first
            training image is neither 2-D nor 3-D.
    """
    if args.pretrained_model in _BUILTIN_MODELS:
        if args.mxnet and args.pretrained_model == 'cyto2':
            logger.warning(
                'cyto2 model not available in mxnet, using cyto model')
            args.pretrained_model = 'cyto'
        cpmodel_path = models.model_path(args.pretrained_model, 0,
                                         not args.mxnet)
        szmean = 30. if args.pretrained_model == 'cyto' else 17.
    else:
        cpmodel_path = os.fspath(args.pretrained_model)
        szmean = 30.

    test_dir = None if len(args.test_dir) == 0 else args.test_dir
    output = io.load_train_test_data(args.dir, test_dir, imf,
                                     args.mask_filter, args.unet,
                                     args.look_one_level_down)
    images, labels, image_names, test_images, test_labels, image_names_test = output

    # training with all channels
    if args.all_channels:
        img = images[0]
        if img.ndim == 3:
            # the smallest axis of the first image is taken as the channel count
            nchan = min(img.shape)
        elif img.ndim == 2:
            nchan = 1
        else:
            # previously nchan stayed unbound here and model construction
            # failed later with an UnboundLocalError
            raise ValueError(
                'cannot infer number of channels from a %d-dimensional training image'
                % img.ndim)
        channels = None
    else:
        nchan = 2

    # model path
    if os.path.exists(cpmodel_path):
        rescale = True
        args.diameter = szmean
        logger.info('>>>> pretrained model %s is being used' % cpmodel_path)
        # architecture flags are forced to these values for a pretrained model
        args.residual_on = 1
        args.style_on = 1
        args.concatenation = 0
    else:
        if not args.train:
            error_message = 'ERROR: model path missing or incorrect - cannot train size model'
            logger.critical(error_message)
            raise ValueError(error_message)
        cpmodel_path = False
        logger.info('>>>> training from scratch')
        if args.diameter == 0:
            rescale = False
            logger.info(
                '>>>> median diameter set to 0 => no rescaling during training'
            )
        else:
            rescale = True
            szmean = args.diameter
    if rescale and args.train:
        logger.info(
            '>>>> during training rescaling images to fixed diameter of %0.1f pixels'
            % args.diameter)

    # initialize model
    if args.unet:
        model = core.UnetModel(device=device,
                               pretrained_model=cpmodel_path,
                               diam_mean=szmean,
                               residual_on=args.residual_on,
                               style_on=args.style_on,
                               concatenation=args.concatenation,
                               nclasses=args.nclasses,
                               nchan=nchan)
    else:
        model = models.CellposeModel(device=device,
                                     torch=(not args.mxnet),
                                     pretrained_model=cpmodel_path,
                                     diam_mean=szmean,
                                     residual_on=args.residual_on,
                                     style_on=args.style_on,
                                     concatenation=args.concatenation,
                                     nchan=nchan)

    # train segmentation model
    if args.train:
        cpmodel_path = model.train(images,
                                   labels,
                                   train_files=image_names,
                                   test_data=test_images,
                                   test_labels=test_labels,
                                   test_files=image_names_test,
                                   learning_rate=args.learning_rate,
                                   channels=channels,
                                   save_path=os.path.realpath(args.dir),
                                   rescale=rescale,
                                   n_epochs=args.n_epochs,
                                   batch_size=args.batch_size)
        model.pretrained_model = cpmodel_path
        logger.info('>>>> model trained and saved to %s' % cpmodel_path)

    # train size model
    if args.train_size:
        sz_model = models.SizeModel(cp_model=model, device=device)
        sz_model.train(images,
                       labels,
                       test_images,
                       test_labels,
                       channels=channels,
                       batch_size=args.batch_size)
        if test_images is not None:
            predicted_diams, diams_style = sz_model.eval(test_images,
                                                         channels=channels)
            if test_labels[0].ndim > 2:
                tlabels = [lbl[0] for lbl in test_labels]
            else:
                tlabels = test_labels
            # ground-truth diameters, computed once and shared by both correlations
            true_diams = np.array(
                [utils.diameters(lbl)[0] for lbl in tlabels])
            ccs = np.corrcoef(diams_style, true_diams)[0, 1]
            cc = np.corrcoef(predicted_diams, true_diams)[0, 1]
            logger.info(
                'style test correlation: %0.4f; final test correlation: %0.4f'
                % (ccs, cc))
            np.save(
                os.path.join(
                    args.test_dir, '%s_predicted_diams.npy' %
                    os.path.split(cpmodel_path)[1]), {
                        'predicted_diams': predicted_diams,
                        'diams_style': diams_style
                    })


def main():
    """Command-line entry point: launch the GUI, segment images, or train.

    Parses ``sys.argv``, sets ``MXNET_SUBGRAPH_BACKEND`` in the environment,
    then dispatches:

    * no ``--dir``: run the GUI, or log why it is unavailable;
    * ``--dir`` without ``--train``/``--train_size``: segment the images;
    * otherwise: train the segmentation and/or size model.
    """
    args = _build_parser().parse_args()

    mkl_enabled = models.check_mkl(
        (not args.mxnet)) if args.check_mkl else True

    # MKLDNN backend is only requested for inference, and only when mkl works
    if not args.train and (mkl_enabled and args.mkldnn):
        os.environ["MXNET_SUBGRAPH_BACKEND"] = "MKLDNN"
    else:
        os.environ["MXNET_SUBGRAPH_BACKEND"] = ""

    if len(args.dir) == 0:
        # no data folder given: fall back to the GUI
        if GUI_ENABLED:
            gui.run()
        else:
            logger.critical('ERROR: %s' % GUI_ERROR)
            if GUI_IMPORT:
                logger.critical(
                    'GUI FAILED: GUI dependencies may not be installed, to install, run'
                )
                logger.critical('     pip install cellpose[gui]')
        return

    channels = [args.chan, args.chan2]

    # an empty image filter is passed on as None
    imf = args.img_filter if len(args.img_filter) > 0 else None

    device, gpu = models.assign_device((not args.mxnet), args.use_gpu)

    if not args.train and not args.train_size:
        _run_segmentation(args, channels, imf, device, gpu)
    else:
        _run_training(args, channels, imf, device)
    def __init__(
        self,
        model_dir=None,
        type="cellpose",
        resume=True,
        pretrained_model=None,
        save_freq=None,
        use_gpu=True,
        diam_mean=30.0,
        residual_on=1,
        learning_rate=0.001,
        batch_size=2,
        channels=(1, 2),
        resample=True,
        flow_threshold=0.4,
        cellprob_threshold=0.0,
        interp=True,
        default_diameter=30,
        style_on=0,
        disable_mkldnn=True,
    ):
        """Build a torch ``CellposeModel``, load weights, and attach an Adam optimizer.

        Weight loading happens in this order: if ``resume`` is true and
        ``<model_dir>/snapshot`` exists it is loaded via ``self.load``;
        otherwise, if ``pretrained_model`` is None, the ``cytotorch_0``
        weights (and ``size_cytotorch_0.npy``) are downloaded into
        ``~/.cellpose/models`` when missing and loaded with ``strict=False``.

        Args:
            model_dir: Directory for this model; required, created if missing.
            type: Must be ``"cellpose"``.
            resume: Whether to load ``<model_dir>/snapshot`` when it exists.
            pretrained_model: Forwarded to ``models.CellposeModel``; None
                triggers the default-weights download described above.
            save_freq: Stored as ``self.save_freq``; when None it defaults to
                2000 if a GPU was assigned, else 300.
            use_gpu: Forwarded to ``models.assign_device``.
            diam_mean, residual_on, style_on, disable_mkldnn: Forwarded to
                ``models.CellposeModel``.
            learning_rate: Stored on the instance and used as the Adam ``lr``.
            batch_size, channels, resample, flow_threshold,
            cellprob_threshold, interp, default_diameter: Stored on the
                instance unchanged.

        Raises:
            AssertionError: If ``type`` is not ``"cellpose"`` or ``model_dir``
                is None.
        """
        assert type == "cellpose"
        assert model_dir is not None
        device, gpu = models.assign_device(True, use_gpu)
        self.learning_rate = learning_rate
        self.channels = channels
        self.batch_size = batch_size
        self.model_dir = model_dir
        self.resample = resample
        self.cellprob_threshold = cellprob_threshold
        self.flow_threshold = flow_threshold
        self.interp = interp
        self.default_diameter = default_diameter
        if save_freq is None:
            save_freq = 2000 if gpu else 300
        self.save_freq = save_freq
        self.model = models.CellposeModel(
            gpu=gpu,
            device=device,
            torch=True,
            pretrained_model=pretrained_model,
            diam_mean=diam_mean,
            residual_on=residual_on,
            style_on=style_on,
            concatenation=0,
            disable_mkldnn=disable_mkldnn,
        )
        os.makedirs(self.model_dir, exist_ok=True)
        if resume:
            resume_weights_path = os.path.join(self.model_dir, "snapshot")
            if os.path.exists(resume_weights_path):
                print("Resuming model from " + resume_weights_path)
                self.load(resume_weights_path)
                # a resumed snapshot takes precedence: skip the default
                # pretrained-weights download below
                pretrained_model = False
            else:
                print("Skipping resume, snapshot does not exist")
        # load the default pretrained weights if no model was specified
        if pretrained_model is None:
            cp_model_dir = Path.home().joinpath(".cellpose", "models")
            os.makedirs(cp_model_dir, exist_ok=True)
            weights_path = cp_model_dir / "cytotorch_0"
            size_path = cp_model_dir / "size_cytotorch_0.npy"
            # fetch whichever of the two files is not cached locally yet
            for target in (weights_path, size_path):
                if not target.exists():
                    urllib.request.urlretrieve(
                        "https://www.cellpose.org/models/" + target.name,
                        str(target))

            print("loading pretrained cellpose model from " +
                  str(weights_path))
            # without a GPU the checkpoint tensors are remapped onto the CPU;
            # map_location=None is torch.load's default
            map_location = None if gpu else torch.device("cpu")
            self.model.net.load_state_dict(
                torch.load(str(weights_path), map_location=map_location),
                strict=False,
            )
        self._iterations = 0
        weight_decay = 0.00001
        # Note: Adam (adaptive learning rate) is used here instead of the SGD
        # used by cellpose; this is meant to make training more robust to
        # different settings.
        self.model.optimizer = torch.optim.Adam(self.model.net.parameters(),
                                                lr=self.learning_rate,
                                                weight_decay=weight_decay)
        self.model._set_criterion()
Esempio n. 3
0
def main():
    parser = argparse.ArgumentParser(description='cellpose parameters')

    # settings for CPU vs GPU
    hardware_args = parser.add_argument_group("hardware arguments")
    hardware_args.add_argument(
        '--use_gpu',
        action='store_true',
        help='use gpu if torch or mxnet with cuda installed')
    hardware_args.add_argument('--check_mkl',
                               action='store_true',
                               help='check if mkl working')
    hardware_args.add_argument(
        '--mkldnn',
        action='store_true',
        help='for mxnet, force MXNET_SUBGRAPH_BACKEND = "MKLDNN"')

    # settings for locating and formatting images
    input_img_args = parser.add_argument_group("input image arguments")
    input_img_args.add_argument(
        '--dir',
        default=[],
        type=str,
        help='folder containing data to run or train on.')
    input_img_args.add_argument(
        '--look_one_level_down',
        action='store_true',
        help='run processing on all subdirectories of current folder')
    input_img_args.add_argument('--mxnet',
                                action='store_true',
                                help='use mxnet')
    input_img_args.add_argument('--img_filter',
                                default=[],
                                type=str,
                                help='end string for images to run on')
    input_img_args.add_argument(
        '--channel_axis',
        default=None,
        type=int,
        help='axis of image which corresponds to image channels')
    input_img_args.add_argument(
        '--z_axis',
        default=None,
        type=int,
        help='axis of image which corresponds to Z dimension')
    input_img_args.add_argument(
        '--chan',
        default=0,
        type=int,
        help=
        'channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s'
    )
    input_img_args.add_argument(
        '--chan2',
        default=0,
        type=int,
        help=
        'nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s'
    )
    input_img_args.add_argument('--invert',
                                action='store_true',
                                help='invert grayscale channel')
    input_img_args.add_argument(
        '--all_channels',
        action='store_true',
        help=
        'use all channels in image if using own model and images with special channels'
    )

    # model settings
    model_args = parser.add_argument_group("model arguments")
    parser.add_argument('--pretrained_model',
                        required=False,
                        default='cyto',
                        type=str,
                        help='model to use')
    parser.add_argument(
        '--unet',
        required=False,
        default=0,
        type=int,
        help='run standard unet instead of cellpose flow output')
    model_args.add_argument(
        '--nclasses',
        default=3,
        type=int,
        help=
        'if running unet, choose 2 or 3; if training omni, choose 4; standard Cellpose uses 3'
    )

    # algorithm settings
    algorithm_args = parser.add_argument_group("algorithm arguments")
    parser.add_argument('--omni',
                        action='store_true',
                        help='Omnipose algorithm (disabled by default)')
    parser.add_argument(
        '--cluster',
        action='store_true',
        help=
        'DBSCAN clustering. Reduces oversegmentation of thin features (disabled by default).'
    )
    parser.add_argument(
        '--fast_mode',
        action='store_true',
        help=
        'make code run faster by turning off 4 network averaging and resampling'
    )
    parser.add_argument(
        '--no_resample',
        action='store_true',
        help=
        "disable dynamics on full image (makes algorithm faster for images with large diameters)"
    )
    parser.add_argument('--no_net_avg',
                        action='store_true',
                        help='make code run faster by only running 1 network')
    parser.add_argument(
        '--no_interp',
        action='store_true',
        help='do not interpolate when running dynamics (was default)')
    parser.add_argument(
        '--do_3D',
        action='store_true',
        help='process images as 3D stacks of images (nplanes x nchan x Ly x Lx'
    )
    parser.add_argument(
        '--diameter',
        required=False,
        default=30.,
        type=float,
        help='cell diameter, if 0 cellpose will estimate for each image')
    parser.add_argument(
        '--stitch_threshold',
        required=False,
        default=0.0,
        type=float,
        help=
        'compute masks in 2D then stitch together masks with IoU>0.9 across planes'
    )

    algorithm_args.add_argument(
        '--flow_threshold',
        default=0.4,
        type=float,
        help=
        'flow error threshold, 0 turns off this optional QC step. Default: %(default)s'
    )
    algorithm_args.add_argument(
        '--mask_threshold',
        default=0,
        type=float,
        help=
        'mask threshold, default is 0, decrease to find more and larger masks')

    parser.add_argument('--anisotropy',
                        required=False,
                        default=1.0,
                        type=float,
                        help='anisotropy of volume in 3D')
    parser.add_argument(
        '--diam_threshold',
        required=False,
        default=12.0,
        type=float,
        help=
        'cell diameter threshold for upscaling before mask rescontruction, default 12.'
    )
    parser.add_argument('--exclude_on_edges',
                        action='store_true',
                        help='discard masks which touch edges of image')

    # output settings
    output_args = parser.add_argument_group("output arguments")
    output_args.add_argument(
        '--save_png',
        action='store_true',
        help='save masks as png and outlines as text file for ImageJ')
    output_args.add_argument(
        '--save_tif',
        action='store_true',
        help='save masks as tif and outlines as text file for ImageJ')
    output_args.add_argument('--no_npy',
                             action='store_true',
                             help='suppress saving of npy')
    output_args.add_argument(
        '--savedir',
        default=None,
        type=str,
        help=
        'folder to which segmentation results will be saved (defaults to input image directory)'
    )
    output_args.add_argument(
        '--dir_above',
        action='store_true',
        help=
        'save output folders adjacent to image folder instead of inside it (off by default)'
    )
    output_args.add_argument(
        '--in_folders',
        action='store_true',
        help='flag to save output in folders (off by default)')
    output_args.add_argument(
        '--save_flows',
        action='store_true',
        help=
        'whether or not to save RGB images of flows when masks are saved (disabled by default)'
    )
    output_args.add_argument(
        '--save_outlines',
        action='store_true',
        help=
        'whether or not to save RGB outline images when masks are saved (disabled by default)'
    )
    output_args.add_argument(
        '--save_ncolor',
        action='store_true',
        help=
        'whether or not to save minimal "n-color" masks (disabled by default')
    output_args.add_argument(
        '--save_txt',
        action='store_true',
        help='flag to enable txt outlines for ImageJ (disabled by default)')

    # training settings
    training_args = parser.add_argument_group("training arguments")
    training_args.add_argument('--train',
                               action='store_true',
                               help='train network using images in dir')
    training_args.add_argument('--train_size',
                               action='store_true',
                               help='train size network at end of training')
    training_args.add_argument(
        '--mask_filter',
        default='_masks',
        type=str,
        help='end string for masks to run on. Default: %(default)s')
    training_args.add_argument('--test_dir',
                               default=[],
                               type=str,
                               help='folder containing test data (optional)')
    training_args.add_argument('--learning_rate',
                               default=0.2,
                               type=float,
                               help='learning rate. Default: %(default)s')
    training_args.add_argument('--n_epochs',
                               default=500,
                               type=int,
                               help='number of epochs. Default: %(default)s')
    training_args.add_argument('--batch_size',
                               default=8,
                               type=int,
                               help='batch size. Default: %(default)s')
    training_args.add_argument(
        '--min_train_masks',
        default=5,
        type=int,
        help=
        'minimum number of masks a training image must have to be used. Default: %(default)s'
    )
    training_args.add_argument('--residual_on',
                               default=1,
                               type=int,
                               help='use residual connections')
    training_args.add_argument('--style_on',
                               default=1,
                               type=int,
                               help='use style vector')
    training_args.add_argument(
        '--concatenation',
        default=0,
        type=int,
        help=
        'concatenate downsampled layers with upsampled layers (off by default which means they are added)'
    )
    training_args.add_argument(
        '--save_every',
        default=100,
        type=int,
        help='number of epochs to skip between saves. Default: %(default)s')
    training_args.add_argument(
        '--save_each',
        action='store_true',
        help=
        'save the model under a different filename per --save_every epoch for later comparsion'
    )

    # misc settings
    parser.add_argument(
        '--verbose',
        action='store_true',
        help=
        'flag to output extra information (e.g. diameter metrics) for debugging and fine-tuning parameters'
    )
    parser.add_argument(
        '--testing',
        action='store_true',
        help=
        'flag to suppress CLI user confirmation for saving output; for test scripts'
    )

    args = parser.parse_args()

    # handle mxnet option
    if args.check_mkl:
        mkl_enabled = models.check_mkl((not args.mxnet))
    else:
        mkl_enabled = True

    if not args.train and (mkl_enabled and args.mkldnn):
        os.environ["MXNET_SUBGRAPH_BACKEND"] = "MKLDNN"
    else:
        os.environ["MXNET_SUBGRAPH_BACKEND"] = ""

    if len(args.dir) == 0:
        if not GUI_ENABLED:
            print('GUI ERROR: %s' % GUI_ERROR)
            if GUI_IMPORT:
                print(
                    'GUI FAILED: GUI dependencies may not be installed, to install, run'
                )
                print('     pip install cellpose[gui]')
        else:
            gui.run()

    else:
        if args.verbose:
            from .io import logger_setup
            logger, log_file = logger_setup()
        else:
            print(
                '>>>> !NEW LOGGING SETUP! To see cellpose progress, set --verbose'
            )
            print('No --verbose => no progress or info printed')
            logger = logging.getLogger(__name__)

        use_gpu = False
        channels = [args.chan, args.chan2]

        # find images
        if len(args.img_filter) > 0:
            imf = args.img_filter
        else:
            imf = None

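        # Record whether the user asked to save anything. No confirmation prompt is issued here;
        # the flag only gates the io.save_masks call in the evaluation loop below.
        # NOTE(review): --save_outlines is not part of this check, so it has no effect on its own - confirm intended.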
        if not (args.train or args.train_size):
            saving_something = args.save_png or args.save_tif or args.save_flows or args.save_ncolor or args.save_txt

        device, gpu = models.assign_device((not args.mxnet), args.use_gpu)

        # Define available model names; right now we have three broad categories.
        model_names = [
            'cyto', 'nuclei', 'bact', 'cyto2', 'bact_omni', 'cyto2_omni'
        ]
        builtin_model = np.any(
            [args.pretrained_model == s for s in model_names])
        cytoplasmic = 'cyto' in args.pretrained_model
        nuclear = 'nuclei' in args.pretrained_model
        bacterial = 'bact' in args.pretrained_model

        # force omni on for those models, but don't toggle it off if manually specified
        if 'omni' in args.pretrained_model:
            args.omni = True

        if args.cluster and 'sklearn' not in sys.modules:
            print('>>>> DBSCAN clustering requires scikit-learn.')
            confirm = confirm_prompt('Install scikit-learn?')
            if confirm:
                install('scikit-learn')
            else:
                print(
                    '>>>> scikit-learn not installed. DBSCAN clustering will be automatically disabled.'
                )

        omni = check_omni(
            args.omni
        )  # repeat the above check but factor it for use elsewhere
        if args.omni:
            print(
                '>>>> Omnipose enabled. See https://raw.githubusercontent.com/MouseLand/cellpose/master/cellpose/omnipose/license.txt for licensing details.'
            )

        if not args.train and not args.train_size:
            tic = time.time()
            if not builtin_model:
                cpmodel_path = args.pretrained_model
                if not os.path.exists(cpmodel_path):
                    logger.warning(
                        'model path does not exist, using cyto model')
                    args.pretrained_model = 'cyto'
                else:
                    logger.info(f'>>> running model {cpmodel_path}')

            image_names = io.get_image_files(
                args.dir,
                args.mask_filter,
                imf=imf,
                look_one_level_down=args.look_one_level_down)
            nimg = len(image_names)

            cstr0 = ['GRAY', 'RED', 'GREEN', 'BLUE']
            cstr1 = ['NONE', 'RED', 'GREEN', 'BLUE']
            logger.info(
                '>>>> running cellpose on %d images using chan_to_seg %s and chan (opt) %s'
                % (nimg, cstr0[channels[0]], cstr1[channels[1]]))
            if args.omni:
                logger.info(f'>>>> omni is ON, cluster is {args.cluster}')

            # handle built-in model exceptions; bacterial ones get no size model
            if builtin_model:
                if args.mxnet:
                    if args.pretrained_model == 'cyto2':
                        logger.warning(
                            'cyto2 model not available in mxnet, using cyto model'
                        )
                        args.pretrained_model = 'cyto'
                    if bacterial:
                        logger.warning(
                            'bacterial models not available in mxnet, using pytorch'
                        )
                        args.mxnet = False
                if not bacterial:
                    model = models.Cellpose(gpu=gpu,
                                            device=device,
                                            model_type=args.pretrained_model,
                                            torch=(not args.mxnet),
                                            omni=args.omni,
                                            net_avg=(not args.fast_mode
                                                     and not args.no_net_avg))
                else:
                    cpmodel_path = models.model_path(args.pretrained_model, 0,
                                                     True)
                    model = models.CellposeModel(gpu=gpu,
                                                 device=device,
                                                 pretrained_model=cpmodel_path,
                                                 torch=True,
                                                 nclasses=args.nclasses,
                                                 omni=args.omni,
                                                 net_avg=False)
            else:
                if args.all_channels:
                    channels = None
                model = models.CellposeModel(gpu=gpu,
                                             device=device,
                                             pretrained_model=cpmodel_path,
                                             torch=True,
                                             nclasses=args.nclasses,
                                             omni=args.omni,
                                             net_avg=False)

            # omni changes not implemented for mxnet. Full parity for cpu/gpu in pytorch.
            if args.omni and args.mxnet:
                logger.info('>>>> omni only implemented in pytorch.')
                confirm = confirm_prompt('Continue with omni set to false?')
                if not confirm:
                    exit()
                else:
                    logger.info('>>>> omni set to false.')
                    args.omni = False

            # For now, omni version is not compatible with 3D. WIP.
            if args.omni and args.do_3D:
                logger.info(
                    '>>>> omni not yet compatible with 3D segmentation.')
                confirm = confirm_prompt('Continue with omni set to false?')
                if not confirm:
                    exit()
                else:
                    logger.info('>>>> omni set to false.')
                    args.omni = False

            # omni model needs 4 classes. Would prefer a more elegant way to automatically update the flow fields
            # instead of users deleting them manually - a check on the number of channels, maybe, or just use
            # the yes/no prompt to ask the user if they want their flow fields in the given directory to be deleted.
            # Would also need the look_one_level_down optionally toggled...
            if args.omni and args.train:
                logger.info('>>>> Training omni model. Setting nclasses to 4.')
                logger.info(
                    '>>>> Make sure your flow fields are deleted and re-computed.'
                )
                args.nclasses = 4

            # handle diameters
            if args.diameter == 0:
                if builtin_model:
                    diameter = None
                    logger.info('>>>> estimating diameter for each image')
                else:
                    logger.info(
                        '>>>> using user-specified model, no auto-diameter estimation available'
                    )
                    diameter = model.diam_mean
            else:
                diameter = args.diameter
                logger.info('>>>> using diameter %0.2f for all images' %
                            diameter)

            tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)

            for image_name in tqdm(image_names, file=tqdm_out):
                image = io.imread(image_name)
                out = model.eval(image,
                                 channels=channels,
                                 diameter=diameter,
                                 do_3D=args.do_3D,
                                 net_avg=(not args.fast_mode
                                          and not args.no_net_avg),
                                 augment=False,
                                 resample=(not args.no_resample
                                           and not args.fast_mode),
                                 flow_threshold=args.flow_threshold,
                                 mask_threshold=args.mask_threshold,
                                 diam_threshold=args.diam_threshold,
                                 invert=args.invert,
                                 batch_size=args.batch_size,
                                 interp=(not args.no_interp),
                                 cluster=args.cluster,
                                 channel_axis=args.channel_axis,
                                 z_axis=args.z_axis,
                                 omni=args.omni,
                                 anisotropy=args.anisotropy,
                                 verbose=args.verbose,
                                 model_loaded=True)
                masks, flows = out[:2]
                if len(out) > 3:
                    diams = out[-1]
                else:
                    diams = diameter
                if args.exclude_on_edges:
                    masks = utils.remove_edge_masks(masks)
                if not args.no_npy:
                    io.masks_flows_to_seg(image, masks, flows, diams,
                                          image_name, channels)
                if saving_something:
                    io.save_masks(image,
                                  masks,
                                  flows,
                                  image_name,
                                  png=args.save_png,
                                  tif=args.save_tif,
                                  save_flows=args.save_flows,
                                  save_outlines=args.save_outlines,
                                  save_ncolor=args.save_ncolor,
                                  dir_above=args.dir_above,
                                  savedir=args.savedir,
                                  save_txt=args.save_txt,
                                  in_folders=args.in_folders)
            logger.info('>>>> completed in %0.3f sec' % (time.time() - tic))
        else:
            if builtin_model:
                if args.mxnet and args.pretrained_model == 'cyto2':
                    logger.warning(
                        'cyto2 model not available in mxnet, using cyto model')
                    args.pretrained_model = 'cyto'
                cpmodel_path = models.model_path(args.pretrained_model, 0,
                                                 not args.mxnet)
                if cytoplasmic:
                    szmean = 30.
                elif nuclear:
                    szmean = 17.
                elif bacterial:
                    szmean = 0.  #bacterial models are not rescaled
            else:
                cpmodel_path = os.fspath(args.pretrained_model)
                szmean = 30.

            test_dir = None if len(args.test_dir) == 0 else args.test_dir
            output = io.load_train_test_data(args.dir, test_dir, imf,
                                             args.mask_filter, args.unet,
                                             args.look_one_level_down)
            images, labels, image_names, test_images, test_labels, image_names_test = output

            # training with all channels
            if args.all_channels:
                img = images[0]
                if img.ndim == 3:
                    nchan = min(img.shape)
                elif img.ndim == 2:
                    nchan = 1
                channels = None
            else:
                nchan = 2

            # model path
            if not os.path.exists(cpmodel_path):
                if not args.train:
                    error_message = 'ERROR: model path missing or incorrect - cannot train size model'
                    logger.critical(error_message)
                    raise ValueError(error_message)
                cpmodel_path = False
                logger.info('>>>> training from scratch')
                if args.diameter == 0:
                    rescale = False
                    logger.info(
                        '>>>> median diameter set to 0 => no rescaling during training'
                    )
                else:
                    rescale = True
                    szmean = args.diameter
            else:
                rescale = True
                args.diameter = szmean
                logger.info('>>>> pretrained model %s is being used' %
                            cpmodel_path)
                args.residual_on = 1
                args.style_on = 1
                args.concatenation = 0
            if rescale and args.train:
                logger.info(
                    '>>>> during training rescaling images to fixed diameter of %0.1f pixels'
                    % args.diameter)

            # initialize model
            if args.unet:
                model = core.UnetModel(device=device,
                                       pretrained_model=cpmodel_path,
                                       diam_mean=szmean,
                                       residual_on=args.residual_on,
                                       style_on=args.style_on,
                                       concatenation=args.concatenation,
                                       nclasses=args.nclasses,
                                       nchan=nchan)
            else:
                model = models.CellposeModel(device=device,
                                             torch=(not args.mxnet),
                                             pretrained_model=cpmodel_path,
                                             diam_mean=szmean,
                                             residual_on=args.residual_on,
                                             style_on=args.style_on,
                                             concatenation=args.concatenation,
                                             nclasses=args.nclasses,
                                             nchan=nchan,
                                             omni=args.omni)

            # train segmentation model
            if args.train:
                cpmodel_path = model.train(
                    images,
                    labels,
                    train_files=image_names,
                    test_data=test_images,
                    test_labels=test_labels,
                    test_files=image_names_test,
                    learning_rate=args.learning_rate,
                    channels=channels,
                    save_path=os.path.realpath(args.dir),
                    save_every=args.save_every,
                    save_each=args.save_each,
                    rescale=rescale,
                    n_epochs=args.n_epochs,
                    batch_size=args.batch_size,
                    min_train_masks=args.min_train_masks,
                    omni=args.omni)
                model.pretrained_model = cpmodel_path
                logger.info('>>>> model trained and saved to %s' %
                            cpmodel_path)

            # train size model
            if args.train_size:
                sz_model = models.SizeModel(cp_model=model, device=device)
                sz_model.train(images,
                               labels,
                               test_images,
                               test_labels,
                               channels=channels,
                               batch_size=args.batch_size)
                if test_images is not None:
                    predicted_diams, diams_style = sz_model.eval(
                        test_images, channels=channels)
                    if test_labels[0].ndim > 2:
                        tlabels = [lbl[0] for lbl in test_labels]
                    else:
                        tlabels = test_labels
                    ccs = np.corrcoef(
                        diams_style,
                        np.array([utils.diameters(lbl)[0]
                                  for lbl in tlabels]))[0, 1]
                    cc = np.corrcoef(
                        predicted_diams,
                        np.array([utils.diameters(lbl)[0]
                                  for lbl in tlabels]))[0, 1]
                    logger.info(
                        'style test correlation: %0.4f; final test correlation: %0.4f'
                        % (ccs, cc))
                    np.save(
                        os.path.join(
                            args.test_dir, '%s_predicted_diams.npy' %
                            os.path.split(cpmodel_path)[1]), {
                                'predicted_diams': predicted_diams,
                                'diams_style': diams_style
                            })