Ejemplo n.º 1
0
def test_timing(test_root, save_root):
    """Time ``CellposeModel.eval`` on batches of 100 identical crops.

    Image ``014_img.tif`` from ``test_root`` is rescaled to a 30-pixel
    diameter, using the value stored in ``predicted_diams.npy``. Square crops
    of 256, 512 and 1024 pixels are then evaluated on the GPU and on the CPU,
    each without (``net_avg=0``) and with (``net_avg=1``) network averaging.
    Elapsed times are printed.

    Args:
        test_root: folder containing ``NNN_img.tif`` images and
            ``predicted_diams.npy``.
        save_root: not used by this function.
    """
    itest = 14
    test_data = io.imread(os.path.join(test_root, '%03d_img.tif' % itest))
    dat = np.load(os.path.join(test_root, 'predicted_diams.npy'),
                  allow_pickle=True).item()
    rescale = 30. / dat['predicted_diams'][itest]
    # assumes the image is read channel-first (nchan, Ly, Lx) -- TODO confirm
    Ly, Lx = test_data.shape[1:]
    test_data = cv2.resize(np.transpose(test_data, (1, 2, 0)),
                           (int(Lx * rescale), int(Ly * rescale)))
    # The largest crop may exceed the rescaled image, so a 2x2 tiled copy is
    # built once here. Tiling inside the device loop would re-tile an already
    # tiled array for the second device and change its crops.
    tiled_data = np.tile(test_data, (2, 2, 1))

    devices = [mx.gpu(), mx.cpu()]
    bsize = [256, 512, 1024]
    # t100[device, crop size, net_avg] -> seconds for 100 images
    t100 = np.zeros((len(devices), len(bsize), 2))
    for d, device in enumerate(devices):
        model = models.CellposeModel(device=device, pretrained_model=None)
        for j, size in enumerate(bsize):
            source = tiled_data if j == 2 else test_data
            img = source[:size, :size]
            imgs = [img] * 100
            for k in (0, 1):
                tic = time.time()
                masks = model.eval(imgs,
                                   channels=[2, 1],
                                   rescale=1.0,
                                   net_avg=k)[0]
                print(masks[0].max())
                t100[d, j, k] = time.time() - tic
                print(t100[d, j, k])
Ejemplo n.º 2
0
def test_model_dir():
    """Load the 'cyto' model with CELLPOSE_LOCAL_MODELS_PATH set to ~/.cellpose.

    Evaluating a random 224x224 image must return masks of the same shape.
    The environment variable is restored afterwards so it does not leak into
    other tests running in the same process.
    """
    import os, pathlib
    import numpy as np

    env_key = "CELLPOSE_LOCAL_MODELS_PATH"
    previous = os.environ.get(env_key)
    os.environ[env_key] = os.fspath(
        pathlib.Path.home().joinpath('.cellpose'))
    try:
        # Imported only after the variable is set, presumably so that any
        # import-time lookup of the models path sees it.
        from cellpose import models
        model = models.CellposeModel(net_avg=False, pretrained_model='cyto')
        masks = model.eval(np.random.randn(224, 224))[0]
    finally:
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous
    assert masks.shape == (224, 224)
Ejemplo n.º 3
0
def test_cellpose_kfold_aug(data_root, save_root):
    """Test the k-fold trained cellpose networks on all cyto images.

    For each fold ``j``:
      * the networks returned by ``get_pretrained_models`` for
        ``train<j>/models`` are loaded;
      * they are run on the ``ntest`` images of ``test<j>`` with
        ``net_avg=True`` and ``augment=True``, rescaled by the per-image
        diameters in ``test<j>/predicted_diams.npy``;
      * the results are scored with ``metrics.average_precision`` at the
        module-level ``thresholds``.

    Args:
        data_root: folder containing ``train<j>/`` and ``test<j>/`` folds.
        save_root: not used by this function.

    Returns:
        Array of shape ``(nfolds, ntest, len(thresholds))`` with the average
        precision of every test image in every fold.
    """
    device = mx.gpu()
    nfolds = 9
    ntest = 68
    concatenation = [0]
    residual_on = [1]
    style_on = [1]
    channels = [2, 1]

    # Sized from ntest so the array always matches the number of images read.
    aps = np.zeros((nfolds, ntest, len(thresholds)))

    for j in range(nfolds):
        train_root = os.path.join(data_root, 'train%d/' % j)
        model_root = os.path.join(train_root, 'models/')

        test_root = os.path.join(data_root, 'test%d/' % j)
        test_data = [
            io.imread(os.path.join(test_root, '%03d_img.tif' % i))
            for i in range(ntest)
        ]
        test_labels = [
            io.imread(os.path.join(test_root, '%03d_masks.tif' % i))
            for i in range(ntest)
        ]

        # index into the single architecture configuration listed above
        k = 0

        pretrained_models = get_pretrained_models(model_root, 0, 3,
                                                  residual_on[k], style_on[k],
                                                  concatenation[k])
        print(pretrained_models)

        cp_model = models.CellposeModel(device=device,
                                        pretrained_model=pretrained_models)

        dat = np.load(os.path.join(test_root, 'predicted_diams.npy'),
                      allow_pickle=True).item()
        rescale = 30. / dat['predicted_diams']

        masks = cp_model.eval(test_data,
                              channels=channels,
                              rescale=rescale,
                              net_avg=True,
                              augment=True)[0]
        ap = metrics.average_precision(test_labels,
                                       masks,
                                       threshold=thresholds)[0]
        print(ap[:, [0, 5, 8]].mean(axis=0))
        aps[j] = ap

    return aps
Ejemplo n.º 4
0
def test_class_train(data_dir, image_names):
    """Train a fresh CellposeModel for 10 epochs on the 2D training set.

    Any ``models`` folder left inside ``<data_dir>/2D/train`` is removed
    first. The ``image_names`` argument is not used: it is overwritten by the
    file names returned from ``io.load_train_test_data``.
    """
    train_path = data_dir.joinpath('2D').joinpath('train')
    train_dir = str(train_path)
    # Start from a clean slate: drop models written by earlier runs.
    stale_models = str(train_path.joinpath('models'))
    shutil.rmtree(stale_models, ignore_errors=True)

    (images, labels, image_names,
     test_images, test_labels, image_names_test) = io.load_train_test_data(
         train_dir, mask_filter='_cyto_masks')

    model = models.CellposeModel(pretrained_model=None, diam_mean=30)
    train_options = {
        'train_files': image_names,
        'test_data': test_images,
        'test_labels': test_labels,
        'test_files': image_names_test,
        'channels': [2, 1],
        'save_path': train_dir,
        'n_epochs': 10,
    }
    cpmodel_path = model.train(images, labels, **train_options)
    print('>>>> model trained and saved to %s' % cpmodel_path)
Ejemplo n.º 5
0
def test_nets_3D(stack, model_root, save_root, test_region=None):
    """Segment a 3D stack with each configured network and save the masks.

    Args:
        stack: image stack evaluated with ``do_3D=True``.
        model_root: folder searched by ``get_pretrained_models`` for U-Net
            weights.
        save_root: output folder; masks go to ``<arch>_3D_masks.npy``.
        test_region: optional index applied to the masks to keep only the
            region where ground truth is labelled; the cropped masks are then
            cleaned again with ``fill_holes_and_remove_small_masks``.
    """
    device = mx.gpu()

    # Only 'unet3' is evaluated; 'unet2' and 'cellpose' are supported by the
    # loop body but currently disabled.
    model_archs = ['unet3']
    # Thresholds found using ground truth, indexed in model_archs order.
    cell_thresholds = [3., 0.25]
    boundary_thresholds = [0., 0.]

    for m, model_arch in enumerate(model_archs):
        if model_arch != 'cellpose':
            weights = get_pretrained_models(model_root,
                                            unet=1,
                                            nclass=int(model_arch[-1]),
                                            residual=0,
                                            style=0,
                                            concatenate=1)
            net = models.UnetModel(device=device, pretrained_model=weights)
            arch_kwargs = {
                'cell_threshold': cell_thresholds[m],
                'boundary_threshold': boundary_thresholds[m],
            }
        else:
            weights = [
                str(Path.home().joinpath('.cellpose/models/cyto_%d' % idx))
                for idx in range(4)
            ]
            net = models.CellposeModel(device=device, pretrained_model=weights)
            arch_kwargs = {}

        masks = net.eval(stack,
                         channels=[2, 1],
                         rescale=30. / 25.,
                         do_3D=True,
                         min_size=2000,
                         **arch_kwargs)[0]

        if test_region is not None:
            masks = utils.fill_holes_and_remove_small_masks(
                masks[test_region], min_size=2000)
        out_file = os.path.join(save_root, '%s_3D_masks.npy' % model_arch)
        np.save(out_file, masks)
Ejemplo n.º 6
0
def test_cellpose(test_root,
                  save_root,
                  pretrained_models,
                  diam_file=None,
                  model_type='cyto'):
    """Test a single cellpose net, or several nets averaged, on all test images.

    Args:
        test_root: folder with ``NNN_img.tif`` test images and, unless
            ``diam_file`` is given, ``predicted_diams.npy``.
        save_root: output folder; masks are written to
            ``cellpose_<model_type>_masks.npy``.
        pretrained_models: weights passed to ``models.CellposeModel``.
        diam_file: optional ``.npy`` file holding a dict with a
            ``'predicted_diams'`` entry.
        model_type: ``'cyto'``, ``'cyto_sp'`` (no rescaling), or a model
            whose name starts with ``'nuclei'``.
    """
    device = mx.gpu()
    ntest = len(glob(os.path.join(test_root, '*_img.tif')))
    # Nuclei models segment one grayscale channel ([0, 0]). Other models use
    # [2, 1] (green to segment, red as the nuclear channel). Check the full
    # 'nuclei' prefix: a 4-character slice can never equal 'nuclei'.
    if model_type.startswith('nuclei'):
        channels = [0, 0]
    else:
        channels = [2, 1]

    test_data = [
        io.imread(os.path.join(test_root, '%03d_img.tif' % i))
        for i in range(ntest)
    ]

    # saved diameters
    if model_type != 'cyto_sp':
        if diam_file is None:
            dat = np.load(os.path.join(test_root, 'predicted_diams.npy'),
                          allow_pickle=True).item()
        else:
            dat = np.load(diam_file, allow_pickle=True).item()
        # Rescale so cells match the diameter the network expects
        # (30 for cyto, 17 otherwise).
        if model_type == 'cyto':
            rescale = 30. / dat['predicted_diams']
        else:
            rescale = 17. / dat['predicted_diams']
    else:
        rescale = np.ones(len(test_data))

    model = models.CellposeModel(device=device,
                                 pretrained_model=pretrained_models)
    masks = model.eval(test_data, channels=channels, rescale=rescale)[0]

    np.save(os.path.join(save_root, 'cellpose_%s_masks.npy' % model_type),
            masks)
Ejemplo n.º 7
0
            # Read one label image per training image; label files are located
            # by get_label_files from image_names, imf and args.mask_filter.
            label_names = get_label_files(image_names, imf, args.mask_filter)
            nimg = len(image_names)
            labels = [skimage.io.imread(label_names[n]) for n in range(nimg)]
            # If the starting weights file is missing, train from scratch
            # (pretrained_model=None) instead of failing.
            if not os.path.exists(cpmodel_path):
                cpmodel_path = None
                print('>>>> training from scratch')
            else:
                print('>>>> training starting with pretrained_model %s' %
                      cpmodel_path)

            # Optional test set, only loaded when args.test_dir is non-empty;
            # otherwise both stay None.
            test_images, test_labels = None, None
            if len(args.test_dir) > 0:
                image_names_test = get_image_files(args.test_dir)
                label_names_test = get_label_files(image_names_test, imf,
                                                   args.mask_filter)
                nimg = len(image_names_test)
                test_images = [
                    skimage.io.imread(image_names_test[n]) for n in range(nimg)
                ]
                test_labels = [
                    skimage.io.imread(label_names_test[n]) for n in range(nimg)
                ]

            # Build the model from cpmodel_path (None => no pretrained weights)
            # and train it, saving under the resolved training directory.
            model = models.CellposeModel(device=device,
                                         pretrained_model=cpmodel_path)
            model.train(images,
                        labels,
                        test_images,
                        test_labels,
                        channels=channels,
                        save_path=os.path.realpath(args.dir))
    def __init__(
        self,
        model_dir=None,
        type="cellpose",
        resume=True,
        pretrained_model=None,
        save_freq=None,
        use_gpu=True,
        diam_mean=30.0,
        residual_on=1,
        learning_rate=0.001,
        batch_size=2,
        channels=(1, 2),
        resample=True,
        flow_threshold=0.4,
        cellprob_threshold=0.0,
        interp=True,
        default_diameter=30,
        style_on=0,
        disable_mkldnn=True,
    ):
        """Build a trainable torch CellposeModel wrapper.

        Args:
            model_dir: required directory; it is created if missing and
                searched for a ``snapshot`` to resume from.
            type: must be ``"cellpose"``.
            resume: if True and ``<model_dir>/snapshot`` exists, load it with
                ``self.load`` and skip the default pretrained download.
            pretrained_model: passed to ``models.CellposeModel``. When None,
                the ``cytotorch_0`` weights are downloaded to
                ``~/.cellpose/models`` (if absent) and loaded non-strictly.
            save_freq: stored as ``self.save_freq``; defaults to 2000 on GPU
                and 300 on CPU.
            use_gpu: request a GPU device via ``models.assign_device``.
            diam_mean, residual_on, style_on, disable_mkldnn: forwarded to
                ``models.CellposeModel``.
            learning_rate: Adam learning rate.
            batch_size, channels, resample, flow_threshold,
            cellprob_threshold, interp, default_diameter: stored on the
                instance for use by other methods.
        """
        assert type == "cellpose", "only type='cellpose' is supported"
        assert model_dir is not None, "model_dir must be provided"
        device, gpu = models.assign_device(True, use_gpu)
        self.learning_rate = learning_rate
        self.channels = channels
        self.batch_size = batch_size
        self.model_dir = model_dir
        self.resample = resample
        self.cellprob_threshold = cellprob_threshold
        self.flow_threshold = flow_threshold
        self.interp = interp
        self.default_diameter = default_diameter
        if save_freq is None:
            self.save_freq = 2000 if gpu else 300
        else:
            self.save_freq = save_freq
        self.model = models.CellposeModel(
            gpu=gpu,
            device=device,
            torch=True,
            pretrained_model=pretrained_model,
            diam_mean=diam_mean,
            residual_on=residual_on,
            style_on=style_on,
            concatenation=0,
            disable_mkldnn=disable_mkldnn,
        )
        os.makedirs(self.model_dir, exist_ok=True)
        if resume:
            resume_weights_path = os.path.join(self.model_dir, "snapshot")
            if os.path.exists(resume_weights_path):
                print("Resuming model from " + resume_weights_path)
                self.load(resume_weights_path)
                # disable pretrained model
                pretrained_model = False
            else:
                print("Skipping resume, snapshot does not exist")
        # load pretrained model weights if not specified
        if pretrained_model is None:
            cp_model_dir = Path.home().joinpath(".cellpose", "models")
            os.makedirs(cp_model_dir, exist_ok=True)
            weights_path = cp_model_dir / "cytotorch_0"
            size_path = cp_model_dir / "size_cytotorch_0.npy"

            def _fetch(url, dest):
                # Download to a sibling ".part" file and rename it into place,
                # so an interrupted download never leaves a truncated file
                # that the exists() checks below would accept as valid.
                partial = str(dest) + ".part"
                urllib.request.urlretrieve(url, partial)
                os.replace(partial, str(dest))

            if not weights_path.exists():
                _fetch("https://www.cellpose.org/models/cytotorch_0",
                       weights_path)
            if not size_path.exists():
                _fetch("https://www.cellpose.org/models/size_cytotorch_0.npy",
                       size_path)

            print("loading pretrained cellpose model from " +
                  str(weights_path))
            # NOTE(review): torch.load unpickles the downloaded file, so it is
            # only as trustworthy as the download source.
            if gpu:
                state_dict = torch.load(str(weights_path))
            else:
                state_dict = torch.load(str(weights_path),
                                        map_location=torch.device("cpu"))
            self.model.net.load_state_dict(state_dict, strict=False)
        self._iterations = 0
        weight_decay = 1e-5
        # Note: Adam (adaptive learning rate) is used here instead of the SGD
        # used by cellpose's own training, to make training more robust to
        # different settings.
        self.model.optimizer = torch.optim.Adam(self.model.net.parameters(),
                                                lr=self.learning_rate,
                                                weight_decay=weight_decay)
        self.model._set_criterion()
    def run(self, workspace):
        """Segment the input image with Cellpose and publish the objects.

        The model is either a built-in one ('cyto' or 'nuclei', depending on
        the mode) or a user-supplied weights file. A nuclei image can
        optionally be stacked in as an extra channel. The label image from
        ``model.eval`` is added to the workspace as ``Objects``. When
        requested, the probability map (``flows[2]``) is also saved, resized
        to the mask shape.

        Raises:
            ValueError: if the input image is multichannel.
        """
        mode = self.mode.value
        if mode == MODE_CUSTOM:
            model_file = self.model_file_name.value
            custom_path = os.path.join(
                self.model_directory.get_absolute_path(), model_file)
            model = models.CellposeModel(pretrained_model=custom_path,
                                         gpu=self.use_gpu.value)
        else:
            builtin_type = 'cyto' if mode == MODE_CELLS else 'nuclei'
            model = models.Cellpose(model_type=builtin_type,
                                    gpu=self.use_gpu.value)

        y_name = self.y_name.value
        images = workspace.image_set
        x = images.get_image(self.x_name.value)
        dimensions = x.dimensions
        x_data = x.pixel_data

        if x.multichannel:
            raise ValueError(
                "Color images are not currently supported. Please provide greyscale images."
            )

        if mode != "Nuclei" and self.supply_nuclei.value:
            nuc_image = images.get_image(self.nuclei_image.value)
            # Build an RGB-style input: empty red, cells in green, nuclei in
            # blue. The channel axis sits after Z for volumes.
            stack_axis = 1 if x.volumetric else -1
            x_data = numpy.stack(
                (numpy.zeros_like(x_data), x_data, nuc_image.pixel_data),
                axis=stack_axis)
            channels = [2, 3]
        else:
            channels = [0, 0]

        expected = self.expected_diameter.value
        diam = expected if expected > 0 else None

        try:
            y_data, flows, *_ = model.eval(
                x_data,
                channels=channels,
                diameter=diam,
                net_avg=self.use_averaging.value,
                do_3D=x.volumetric,
                flow_threshold=self.flow_threshold.value,
                cellprob_threshold=self.dist_threshold.value)
        finally:
            # Best effort: free cached GPU memory for other worker processes,
            # whether or not eval succeeded.
            if self.use_gpu.value and model.torch:
                try:
                    from torch import cuda
                    cuda.empty_cache()
                except Exception as e:
                    print(
                        f"Unable to clear GPU memory. You may need to restart CellProfiler to change models. {e}"
                    )

        segmented = Objects()
        segmented.segmented = y_data
        segmented.parent_image = x.parent_image
        workspace.object_set.add_objects(segmented, y_name)

        if self.save_probabilities.value:
            # Flows are sized relative to Cellpose's internal model size, so
            # resize them to match the segmentation.
            probabilities = resize(flows[2], y_data.shape)
            prob_image = Image(
                probabilities,
                parent_image=x.parent_image,
                convert=False,
                dimensions=len(probabilities.shape),
            )
            workspace.image_set.add(self.probabilities_name.value, prob_image)
            if self.show_window:
                workspace.display_data.probabilities = probabilities

        self.add_measurements(workspace)

        if self.show_window:
            display = workspace.display_data
            # Stacked colour input cannot be shown in 3D; show the raw image
            # for volumes instead.
            display.x_data = x.pixel_data if x.volumetric else x_data
            display.y_data = y_data
            display.dimensions = dimensions
Ejemplo n.º 10
0
# Flask-Dropzone: one image per upload, then redirect to the plot view.
# Uploads settings (Flask-Uploads): uploaded photos are stored in /tmp.
app.config.update(
    DROPZONE_UPLOAD_MULTIPLE=False,
    DROPZONE_ALLOWED_FILE_CUSTOM=True,
    DROPZONE_ALLOWED_FILE_TYPE='image/*',
    DROPZONE_REDIRECT_VIEW='image_plot',
    DROPZONE_MAX_FILE_SIZE=10,
    DROPZONE_MAX_FILES=1,
    UPLOADED_PHOTOS_DEST='/tmp',
)

photos = UploadSet('photos', IMAGES)
configure_uploads(app, photos)
# Apply Flask-Uploads' default maximum request size to the app.
patch_request_class(app)

# A single CPU model, loaded once at import time and shared by all requests.
model = models.CellposeModel(device=mx.cpu(),
                             pretrained_model='static/models/cyto_0')


def url_to_image(file_url):
    """Fetch an image from ``file_url`` and decode it as an RGB array.

    Args:
        file_url: any URL accepted by ``urllib.request.urlopen``.

    Returns:
        The decoded image, converted from OpenCV's BGR order to RGB.
    """
    # NOTE(review): if file_url can come from users, urlopen also accepts
    # schemes such as file:// -- consider validating the URL first.
    # The context manager closes the HTTP response even if read() fails.
    with urllib.request.urlopen(file_url) as resp:
        raw = resp.read()
    buf = np.asarray(bytearray(raw), dtype="uint8")
    img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def image_resize(img, resize=512):
    ny, nx = img.shape[:2]
    if np.array(img.shape).max() > resize:
        if ny > nx:
            nx = int(nx / ny * resize)
Ejemplo n.º 11
0
def main():

    parser = argparse.ArgumentParser(description='cellpose parameters')
    parser.add_argument('--check_mkl',
                        action='store_true',
                        help='check if mkl working')
    parser.add_argument(
        '--mkldnn',
        action='store_true',
        help='for mxnet, force MXNET_SUBGRAPH_BACKEND = "MKLDNN"')
    parser.add_argument('--train',
                        action='store_true',
                        help='train network using images in dir')
    parser.add_argument('--dir',
                        required=False,
                        default=[],
                        type=str,
                        help='folder containing data to run or train on')
    parser.add_argument('--look_one_level_down', action='store_true', help='')
    parser.add_argument('--mxnet', action='store_true', help='use mxnet')
    parser.add_argument('--img_filter',
                        required=False,
                        default=[],
                        type=str,
                        help='end string for images to run on')
    parser.add_argument('--use_gpu',
                        action='store_true',
                        help='use gpu if mxnet with cuda installed')
    parser.add_argument(
        '--fast_mode',
        action='store_true',
        help="make code run faster by turning off 4 network averaging")
    parser.add_argument(
        '--resample',
        action='store_true',
        help=
        "run dynamics on full image (slower for images with large diameters)")
    parser.add_argument(
        '--no_interp',
        action='store_true',
        help='do not interpolate when running dynamics (was default)')
    # settings for running cellpose
    parser.add_argument(
        '--do_3D',
        action='store_true',
        help='process images as 3D stacks of images (nplanes x nchan x Ly x Lx'
    )
    parser.add_argument('--pretrained_model',
                        required=False,
                        default='cyto',
                        type=str,
                        help='model to use')
    parser.add_argument(
        '--chan',
        required=False,
        default=0,
        type=int,
        help='channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE')
    parser.add_argument(
        '--chan2',
        required=False,
        default=0,
        type=int,
        help=
        'nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE'
    )
    parser.add_argument('--invert',
                        required=False,
                        action='store_true',
                        help='invert grayscale channel')
    parser.add_argument(
        '--all_channels',
        action='store_true',
        help=
        'use all channels in image if using own model and images with special channels'
    )
    parser.add_argument(
        '--diameter',
        required=False,
        default=30.,
        type=float,
        help='cell diameter, if 0 cellpose will estimate for each image')
    parser.add_argument(
        '--stitch_threshold',
        required=False,
        default=0.0,
        type=float,
        help=
        'compute masks in 2D then stitch together masks with IoU>0.9 across planes'
    )
    parser.add_argument(
        '--flow_threshold',
        required=False,
        default=0.4,
        type=float,
        help='flow error threshold, 0 turns off this optional QC step')
    parser.add_argument('--cellprob_threshold',
                        required=False,
                        default=0.0,
                        type=float,
                        help='cell probability threshold, centered at 0.0')
    parser.add_argument('--save_png',
                        action='store_true',
                        help='save masks as png')
    parser.add_argument('--save_outlines',
                        action='store_true',
                        help='save outlines as text file for ImageJ')
    parser.add_argument('--save_tif',
                        action='store_true',
                        help='save masks as tif')
    parser.add_argument('--no_npy',
                        action='store_true',
                        help='suppress saving of npy')
    parser.add_argument(
        '--channel_axis',
        required=False,
        default=None,
        type=int,
        help='axis of image which corresponds to image channels')
    parser.add_argument('--z_axis',
                        required=False,
                        default=None,
                        type=int,
                        help='axis of image which corresponds to Z dimension')
    parser.add_argument('--exclude_on_edges',
                        action='store_true',
                        help='discard masks which touch edges of image')
    parser.add_argument(
        '--unet',
        required=False,
        default=0,
        type=int,
        help='run standard unet instead of cellpose flow output')
    parser.add_argument(
        '--nclasses',
        required=False,
        default=3,
        type=int,
        help='if running unet, choose 2 or 3, otherwise not used')

    # settings for training
    parser.add_argument('--train_size',
                        action='store_true',
                        help='train size network at end of training')
    parser.add_argument('--mask_filter',
                        required=False,
                        default='_masks',
                        type=str,
                        help='end string for masks to run on')
    parser.add_argument('--test_dir',
                        required=False,
                        default=[],
                        type=str,
                        help='folder containing test data (optional)')
    parser.add_argument('--learning_rate',
                        required=False,
                        default=0.2,
                        type=float,
                        help='learning rate')
    parser.add_argument('--n_epochs',
                        required=False,
                        default=500,
                        type=int,
                        help='number of epochs')
    parser.add_argument('--batch_size',
                        required=False,
                        default=8,
                        type=int,
                        help='batch size')
    parser.add_argument('--residual_on',
                        required=False,
                        default=1,
                        type=int,
                        help='use residual connections')
    parser.add_argument('--style_on',
                        required=False,
                        default=1,
                        type=int,
                        help='use style vector')
    parser.add_argument(
        '--concatenation',
        required=False,
        default=0,
        type=int,
        help=
        'concatenate downsampled layers with upsampled layers (off by default which means they are added)'
    )

    args = parser.parse_args()

    if args.check_mkl:
        mkl_enabled = models.check_mkl((not args.mxnet))
    else:
        mkl_enabled = True

    if not args.train and (mkl_enabled and args.mkldnn):
        os.environ["MXNET_SUBGRAPH_BACKEND"] = "MKLDNN"
    else:
        os.environ["MXNET_SUBGRAPH_BACKEND"] = ""

    if len(args.dir) == 0:
        if not GUI_ENABLED:
            logger.critical('ERROR: %s' % GUI_ERROR)
            if GUI_IMPORT:
                logger.critical(
                    'GUI FAILED: GUI dependencies may not be installed, to install, run'
                )
                logger.critical('     pip install cellpose[gui]')
        else:
            gui.run()

    else:
        use_gpu = False
        channels = [args.chan, args.chan2]

        # find images
        if len(args.img_filter) > 0:
            imf = args.img_filter
        else:
            imf = None

        device, gpu = models.assign_device((not args.mxnet), args.use_gpu)

        if not args.train and not args.train_size:
            tic = time.time()
            if not (args.pretrained_model == 'cyto' or args.pretrained_model
                    == 'nuclei' or args.pretrained_model == 'cyto2'):
                cpmodel_path = args.pretrained_model
                if not os.path.exists(cpmodel_path):
                    logger.warning(
                        'model path does not exist, using cyto model')
                    args.pretrained_model = 'cyto'

            image_names = io.get_image_files(
                args.dir,
                args.mask_filter,
                imf=imf,
                look_one_level_down=args.look_one_level_down)
            nimg = len(image_names)

            cstr0 = ['GRAY', 'RED', 'GREEN', 'BLUE']
            cstr1 = ['NONE', 'RED', 'GREEN', 'BLUE']
            logger.info(
                '>>>> running cellpose on %d images using chan_to_seg %s and chan (opt) %s'
                % (nimg, cstr0[channels[0]], cstr1[channels[1]]))

            if args.pretrained_model == 'cyto' or args.pretrained_model == 'nuclei' or args.pretrained_model == 'cyto2':
                if args.mxnet and args.pretrained_model == 'cyto2':
                    logger.warning(
                        'cyto2 model not available in mxnet, using cyto model')
                    args.pretrained_model = 'cyto'
                model = models.Cellpose(gpu=gpu,
                                        device=device,
                                        model_type=args.pretrained_model,
                                        torch=(not args.mxnet))
            else:
                if args.all_channels:
                    channels = None
                model = models.CellposeModel(gpu=gpu,
                                             device=device,
                                             pretrained_model=cpmodel_path,
                                             torch=(not args.mxnet))

            if args.diameter == 0:
                if args.pretrained_model == 'cyto' or args.pretrained_model == 'nuclei' or args.pretrained_model == 'cyto2':
                    diameter = None
                    logger.info('>>>> estimating diameter for each image')
                else:
                    logger.info(
                        '>>>> using user-specified model, no auto-diameter estimation available'
                    )
                    diameter = model.diam_mean
            else:
                diameter = args.diameter
                logger.info('>>>> using diameter %0.2f for all images' %
                            diameter)

            tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)
            for image_name in tqdm(image_names, file=tqdm_out):
                image = io.imread(image_name)
                out = model.eval(image,
                                 channels=channels,
                                 diameter=diameter,
                                 do_3D=args.do_3D,
                                 net_avg=(not args.fast_mode),
                                 augment=False,
                                 resample=args.resample,
                                 flow_threshold=args.flow_threshold,
                                 cellprob_threshold=args.cellprob_threshold,
                                 invert=args.invert,
                                 batch_size=args.batch_size,
                                 interp=(not args.no_interp),
                                 channel_axis=args.channel_axis,
                                 z_axis=args.z_axis)
                masks, flows = out[:2]
                if len(out) > 3:
                    diams = out[-1]
                else:
                    diams = diameter
                if args.exclude_on_edges:
                    masks = utils.remove_edge_masks(masks)
                if not args.no_npy:
                    io.masks_flows_to_seg(image, masks, flows, diams,
                                          image_name, channels)
                if args.save_png or args.save_tif or args.save_outlines:
                    io.save_masks(image,
                                  masks,
                                  flows,
                                  image_name,
                                  png=args.save_png,
                                  tif=args.save_tif,
                                  outlines=args.save_outlines)
            logger.info('>>>> completed in %0.3f sec' % (time.time() - tic))
        else:
            if args.pretrained_model == 'cyto' or args.pretrained_model == 'nuclei' or args.pretrained_model == 'cyto2':
                if args.mxnet and args.pretrained_model == 'cyto2':
                    logger.warning(
                        'cyto2 model not available in mxnet, using cyto model')
                    args.pretrained_model = 'cyto'
                cpmodel_path = models.model_path(args.pretrained_model, 0,
                                                 not args.mxnet)
                if args.pretrained_model == 'cyto':
                    szmean = 30.
                else:
                    szmean = 17.
            else:
                cpmodel_path = os.fspath(args.pretrained_model)
                szmean = 30.

            test_dir = None if len(args.test_dir) == 0 else args.test_dir
            output = io.load_train_test_data(args.dir, test_dir, imf,
                                             args.mask_filter, args.unet,
                                             args.look_one_level_down)
            images, labels, image_names, test_images, test_labels, image_names_test = output

            # training with all channels
            if args.all_channels:
                img = images[0]
                if img.ndim == 3:
                    nchan = min(img.shape)
                elif img.ndim == 2:
                    nchan = 1
                channels = None
            else:
                nchan = 2

            # model path
            if not os.path.exists(cpmodel_path):
                if not args.train:
                    error_message = 'ERROR: model path missing or incorrect - cannot train size model'
                    logger.critical(error_message)
                    raise ValueError(error_message)
                cpmodel_path = False
                logger.info('>>>> training from scratch')
                if args.diameter == 0:
                    rescale = False
                    logger.info(
                        '>>>> median diameter set to 0 => no rescaling during training'
                    )
                else:
                    rescale = True
                    szmean = args.diameter
            else:
                rescale = True
                args.diameter = szmean
                logger.info('>>>> pretrained model %s is being used' %
                            cpmodel_path)
                args.residual_on = 1
                args.style_on = 1
                args.concatenation = 0
            if rescale and args.train:
                logger.info(
                    '>>>> during training rescaling images to fixed diameter of %0.1f pixels'
                    % args.diameter)

            # initialize model
            if args.unet:
                model = core.UnetModel(device=device,
                                       pretrained_model=cpmodel_path,
                                       diam_mean=szmean,
                                       residual_on=args.residual_on,
                                       style_on=args.style_on,
                                       concatenation=args.concatenation,
                                       nclasses=args.nclasses,
                                       nchan=nchan)
            else:
                model = models.CellposeModel(device=device,
                                             torch=(not args.mxnet),
                                             pretrained_model=cpmodel_path,
                                             diam_mean=szmean,
                                             residual_on=args.residual_on,
                                             style_on=args.style_on,
                                             concatenation=args.concatenation,
                                             nchan=nchan)

            # train segmentation model
            if args.train:
                cpmodel_path = model.train(images,
                                           labels,
                                           train_files=image_names,
                                           test_data=test_images,
                                           test_labels=test_labels,
                                           test_files=image_names_test,
                                           learning_rate=args.learning_rate,
                                           channels=channels,
                                           save_path=os.path.realpath(
                                               args.dir),
                                           rescale=rescale,
                                           n_epochs=args.n_epochs,
                                           batch_size=args.batch_size)
                model.pretrained_model = cpmodel_path
                logger.info('>>>> model trained and saved to %s' %
                            cpmodel_path)

            # train size model
            if args.train_size:
                sz_model = models.SizeModel(cp_model=model, device=device)
                sz_model.train(images,
                               labels,
                               test_images,
                               test_labels,
                               channels=channels,
                               batch_size=args.batch_size)
                if test_images is not None:
                    predicted_diams, diams_style = sz_model.eval(
                        test_images, channels=channels)
                    if test_labels[0].ndim > 2:
                        tlabels = [lbl[0] for lbl in test_labels]
                    else:
                        tlabels = test_labels
                    ccs = np.corrcoef(
                        diams_style,
                        np.array([utils.diameters(lbl)[0]
                                  for lbl in tlabels]))[0, 1]
                    cc = np.corrcoef(
                        predicted_diams,
                        np.array([utils.diameters(lbl)[0]
                                  for lbl in tlabels]))[0, 1]
                    logger.info(
                        'style test correlation: %0.4f; final test correlation: %0.4f'
                        % (ccs, cc))
                    np.save(
                        os.path.join(
                            args.test_dir, '%s_predicted_diams.npy' %
                            os.path.split(cpmodel_path)[1]), {
                                'predicted_diams': predicted_diams,
                                'diams_style': diams_style
                            })
Ejemplo n.º 12
0
def main():
    """Command-line entry point for cellpose / omnipose.

    Parses the command-line arguments, then:

    * with no ``--dir``: launches the GUI, or prints why it cannot start;
    * with ``--dir`` and neither ``--train`` nor ``--train_size``: segments every
      image found in ``--dir`` and writes the requested outputs;
    * with ``--train`` and/or ``--train_size``: trains a segmentation network
      and/or a size model on the images/masks found in ``--dir`` (and
      ``--test_dir`` if given).

    Interactive prompts (scikit-learn install, omni incompatibilities) may end
    the process via ``sys.exit()`` when the user declines.
    """
    parser = argparse.ArgumentParser(description='cellpose parameters')

    # settings for CPU vs GPU
    hardware_args = parser.add_argument_group("hardware arguments")
    hardware_args.add_argument(
        '--use_gpu',
        action='store_true',
        help='use gpu if torch or mxnet with cuda installed')
    hardware_args.add_argument('--check_mkl',
                               action='store_true',
                               help='check if mkl working')
    hardware_args.add_argument(
        '--mkldnn',
        action='store_true',
        help='for mxnet, force MXNET_SUBGRAPH_BACKEND = "MKLDNN"')

    # settings for locating and formatting images
    input_img_args = parser.add_argument_group("input image arguments")
    input_img_args.add_argument(
        '--dir',
        default=[],
        type=str,
        help='folder containing data to run or train on.')
    input_img_args.add_argument(
        '--look_one_level_down',
        action='store_true',
        help='run processing on all subdirectories of current folder')
    input_img_args.add_argument('--mxnet',
                                action='store_true',
                                help='use mxnet')
    input_img_args.add_argument('--img_filter',
                                default=[],
                                type=str,
                                help='end string for images to run on')
    input_img_args.add_argument(
        '--channel_axis',
        default=None,
        type=int,
        help='axis of image which corresponds to image channels')
    input_img_args.add_argument(
        '--z_axis',
        default=None,
        type=int,
        help='axis of image which corresponds to Z dimension')
    input_img_args.add_argument(
        '--chan',
        default=0,
        type=int,
        help=
        'channel to segment; 0: GRAY, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s'
    )
    input_img_args.add_argument(
        '--chan2',
        default=0,
        type=int,
        help=
        'nuclear channel (if cyto, optional); 0: NONE, 1: RED, 2: GREEN, 3: BLUE. Default: %(default)s'
    )
    input_img_args.add_argument('--invert',
                                action='store_true',
                                help='invert grayscale channel')
    input_img_args.add_argument(
        '--all_channels',
        action='store_true',
        help=
        'use all channels in image if using own model and images with special channels'
    )

    # model settings (argument groups only affect --help layout, not parsing)
    model_args = parser.add_argument_group("model arguments")
    model_args.add_argument('--pretrained_model',
                            required=False,
                            default='cyto',
                            type=str,
                            help='model to use')
    model_args.add_argument(
        '--unet',
        required=False,
        default=0,
        type=int,
        help='run standard unet instead of cellpose flow output')
    model_args.add_argument(
        '--nclasses',
        default=3,
        type=int,
        help=
        'if running unet, choose 2 or 3; if training omni, choose 4; standard Cellpose uses 3'
    )

    # algorithm settings
    algorithm_args = parser.add_argument_group("algorithm arguments")
    algorithm_args.add_argument(
        '--omni',
        action='store_true',
        help='Omnipose algorithm (disabled by default)')
    algorithm_args.add_argument(
        '--cluster',
        action='store_true',
        help=
        'DBSCAN clustering. Reduces oversegmentation of thin features (disabled by default).'
    )
    algorithm_args.add_argument(
        '--fast_mode',
        action='store_true',
        help=
        'make code run faster by turning off 4 network averaging and resampling'
    )
    algorithm_args.add_argument(
        '--no_resample',
        action='store_true',
        help=
        "disable dynamics on full image (makes algorithm faster for images with large diameters)"
    )
    algorithm_args.add_argument(
        '--no_net_avg',
        action='store_true',
        help='make code run faster by only running 1 network')
    algorithm_args.add_argument(
        '--no_interp',
        action='store_true',
        help='do not interpolate when running dynamics (was default)')
    algorithm_args.add_argument(
        '--do_3D',
        action='store_true',
        help='process images as 3D stacks of images (nplanes x nchan x Ly x Lx)'
    )
    algorithm_args.add_argument(
        '--diameter',
        required=False,
        default=30.,
        type=float,
        help='cell diameter, if 0 cellpose will estimate for each image')
    algorithm_args.add_argument(
        '--stitch_threshold',
        required=False,
        default=0.0,
        type=float,
        help=
        'compute masks in 2D then stitch together masks with IoU > stitch_threshold across planes'
    )
    algorithm_args.add_argument(
        '--flow_threshold',
        default=0.4,
        type=float,
        help=
        'flow error threshold, 0 turns off this optional QC step. Default: %(default)s'
    )
    algorithm_args.add_argument(
        '--mask_threshold',
        default=0,
        type=float,
        help=
        'mask threshold, default is 0, decrease to find more and larger masks')
    algorithm_args.add_argument('--anisotropy',
                                required=False,
                                default=1.0,
                                type=float,
                                help='anisotropy of volume in 3D')
    algorithm_args.add_argument(
        '--diam_threshold',
        required=False,
        default=12.0,
        type=float,
        help=
        'cell diameter threshold for upscaling before mask reconstruction, default 12.'
    )
    algorithm_args.add_argument(
        '--exclude_on_edges',
        action='store_true',
        help='discard masks which touch edges of image')

    # output settings
    output_args = parser.add_argument_group("output arguments")
    output_args.add_argument(
        '--save_png',
        action='store_true',
        help='save masks as png and outlines as text file for ImageJ')
    output_args.add_argument(
        '--save_tif',
        action='store_true',
        help='save masks as tif and outlines as text file for ImageJ')
    output_args.add_argument('--no_npy',
                             action='store_true',
                             help='suppress saving of npy')
    output_args.add_argument(
        '--savedir',
        default=None,
        type=str,
        help=
        'folder to which segmentation results will be saved (defaults to input image directory)'
    )
    output_args.add_argument(
        '--dir_above',
        action='store_true',
        help=
        'save output folders adjacent to image folder instead of inside it (off by default)'
    )
    output_args.add_argument(
        '--in_folders',
        action='store_true',
        help='flag to save output in folders (off by default)')
    output_args.add_argument(
        '--save_flows',
        action='store_true',
        help=
        'whether or not to save RGB images of flows when masks are saved (disabled by default)'
    )
    output_args.add_argument(
        '--save_outlines',
        action='store_true',
        help=
        'whether or not to save RGB outline images when masks are saved (disabled by default)'
    )
    output_args.add_argument(
        '--save_ncolor',
        action='store_true',
        help=
        'whether or not to save minimal "n-color" masks (disabled by default)')
    output_args.add_argument(
        '--save_txt',
        action='store_true',
        help='flag to enable txt outlines for ImageJ (disabled by default)')

    # training settings
    training_args = parser.add_argument_group("training arguments")
    training_args.add_argument('--train',
                               action='store_true',
                               help='train network using images in dir')
    training_args.add_argument('--train_size',
                               action='store_true',
                               help='train size network at end of training')
    training_args.add_argument(
        '--mask_filter',
        default='_masks',
        type=str,
        help='end string for masks to run on. Default: %(default)s')
    training_args.add_argument('--test_dir',
                               default=[],
                               type=str,
                               help='folder containing test data (optional)')
    training_args.add_argument('--learning_rate',
                               default=0.2,
                               type=float,
                               help='learning rate. Default: %(default)s')
    training_args.add_argument('--n_epochs',
                               default=500,
                               type=int,
                               help='number of epochs. Default: %(default)s')
    training_args.add_argument('--batch_size',
                               default=8,
                               type=int,
                               help='batch size. Default: %(default)s')
    training_args.add_argument(
        '--min_train_masks',
        default=5,
        type=int,
        help=
        'minimum number of masks a training image must have to be used. Default: %(default)s'
    )
    training_args.add_argument('--residual_on',
                               default=1,
                               type=int,
                               help='use residual connections')
    training_args.add_argument('--style_on',
                               default=1,
                               type=int,
                               help='use style vector')
    training_args.add_argument(
        '--concatenation',
        default=0,
        type=int,
        help=
        'concatenate downsampled layers with upsampled layers (off by default which means they are added)'
    )
    training_args.add_argument(
        '--save_every',
        default=100,
        type=int,
        help='number of epochs to skip between saves. Default: %(default)s')
    training_args.add_argument(
        '--save_each',
        action='store_true',
        help=
        'save the model under a different filename per --save_every epoch for later comparison'
    )

    # misc settings
    parser.add_argument(
        '--verbose',
        action='store_true',
        help=
        'flag to output extra information (e.g. diameter metrics) for debugging and fine-tuning parameters'
    )
    parser.add_argument(
        '--testing',
        action='store_true',
        help=
        'flag to suppress CLI user confirmation for saving output; for test scripts'
    )

    args = parser.parse_args()

    # MKL / MKLDNN configuration (MXNET_SUBGRAPH_BACKEND only affects mxnet)
    if args.check_mkl:
        mkl_enabled = models.check_mkl((not args.mxnet))
    else:
        mkl_enabled = True

    if not args.train and (mkl_enabled and args.mkldnn):
        os.environ["MXNET_SUBGRAPH_BACKEND"] = "MKLDNN"
    else:
        os.environ["MXNET_SUBGRAPH_BACKEND"] = ""

    # no input folder => GUI mode
    if len(args.dir) == 0:
        if not GUI_ENABLED:
            print('GUI ERROR: %s' % GUI_ERROR)
            if GUI_IMPORT:
                print(
                    'GUI FAILED: GUI dependencies may not be installed, to install, run'
                )
                print('     pip install cellpose[gui]')
        else:
            gui.run()
        return

    if args.verbose:
        from .io import logger_setup
        logger, log_file = logger_setup()
    else:
        print(
            '>>>> !NEW LOGGING SETUP! To see cellpose progress, set --verbose')
        print('No --verbose => no progress or info printed')
        logger = logging.getLogger(__name__)

    channels = [args.chan, args.chan2]

    # optional filename suffix used to select images
    imf = args.img_filter if len(args.img_filter) > 0 else None

    device, gpu = models.assign_device((not args.mxnet), args.use_gpu)

    # available built-in model names; the substrings below pick the category
    model_names = [
        'cyto', 'nuclei', 'bact', 'cyto2', 'bact_omni', 'cyto2_omni'
    ]
    builtin_model = args.pretrained_model in model_names
    cytoplasmic = 'cyto' in args.pretrained_model
    nuclear = 'nuclei' in args.pretrained_model
    bacterial = 'bact' in args.pretrained_model

    # force omni on for omni models, but don't toggle it off if manually specified
    if 'omni' in args.pretrained_model:
        args.omni = True

    if args.cluster and 'sklearn' not in sys.modules:
        print('>>>> DBSCAN clustering requires scikit-learn.')
        confirm = confirm_prompt('Install scikit-learn?')
        if confirm:
            install('scikit-learn')
        else:
            print(
                '>>>> scikit-learn not installed. DBSCAN clustering will be automatically disabled.'
            )
            # make the message above true instead of relying on downstream code
            args.cluster = False

    # repeat the omni dependency check (factored out for use elsewhere)
    check_omni(args.omni)
    if args.omni:
        print(
            '>>>> Omnipose enabled. See https://raw.githubusercontent.com/MouseLand/cellpose/master/cellpose/omnipose/license.txt for licensing details.'
        )

    if not args.train and not args.train_size:
        # ------------------------------------------------------------ evaluation
        tic = time.time()
        if not builtin_model:
            cpmodel_path = args.pretrained_model
            if not os.path.exists(cpmodel_path):
                logger.warning('model path does not exist, using cyto model')
                args.pretrained_model = 'cyto'
                # Really switch to the built-in cyto model; otherwise the missing
                # path would still be handed to CellposeModel below.
                builtin_model = True
                bacterial = False
            else:
                logger.info(f'>>> running model {cpmodel_path}')

        image_names = io.get_image_files(
            args.dir,
            args.mask_filter,
            imf=imf,
            look_one_level_down=args.look_one_level_down)
        nimg = len(image_names)

        cstr0 = ['GRAY', 'RED', 'GREEN', 'BLUE']
        cstr1 = ['NONE', 'RED', 'GREEN', 'BLUE']
        logger.info(
            '>>>> running cellpose on %d images using chan_to_seg %s and chan (opt) %s'
            % (nimg, cstr0[channels[0]], cstr1[channels[1]]))
        if args.omni:
            logger.info(f'>>>> omni is ON, cluster is {args.cluster}')

        # built-in model exceptions for mxnet (must precede the omni/mxnet check,
        # since bacterial models switch the backend to pytorch)
        if builtin_model and args.mxnet:
            if args.pretrained_model == 'cyto2':
                logger.warning(
                    'cyto2 model not available in mxnet, using cyto model')
                args.pretrained_model = 'cyto'
            if bacterial:
                logger.warning(
                    'bacterial models not available in mxnet, using pytorch')
                args.mxnet = False

        # omni changes not implemented for mxnet. Full parity for cpu/gpu in
        # pytorch. Checked before building the model so it gets the final flag.
        if args.omni and args.mxnet:
            logger.info('>>>> omni only implemented in pytorch.')
            confirm = confirm_prompt('Continue with omni set to false?')
            if not confirm:
                sys.exit()
            logger.info('>>>> omni set to false.')
            args.omni = False

        # For now, omni version is not compatible with 3D. WIP.
        if args.omni and args.do_3D:
            logger.info('>>>> omni not yet compatible with 3D segmentation.')
            confirm = confirm_prompt('Continue with omni set to false?')
            if not confirm:
                sys.exit()
            logger.info('>>>> omni set to false.')
            args.omni = False

        # build the model; bacterial built-ins get no size model
        if builtin_model:
            if not bacterial:
                model = models.Cellpose(gpu=gpu,
                                        device=device,
                                        model_type=args.pretrained_model,
                                        torch=(not args.mxnet),
                                        omni=args.omni,
                                        net_avg=(not args.fast_mode
                                                 and not args.no_net_avg))
            else:
                cpmodel_path = models.model_path(args.pretrained_model, 0,
                                                 True)
                model = models.CellposeModel(gpu=gpu,
                                             device=device,
                                             pretrained_model=cpmodel_path,
                                             torch=True,
                                             nclasses=args.nclasses,
                                             omni=args.omni,
                                             net_avg=False)
        else:
            if args.all_channels:
                channels = None
            model = models.CellposeModel(gpu=gpu,
                                         device=device,
                                         pretrained_model=cpmodel_path,
                                         torch=True,
                                         nclasses=args.nclasses,
                                         omni=args.omni,
                                         net_avg=False)

        # handle diameters
        if args.diameter == 0:
            if builtin_model:
                diameter = None
                logger.info('>>>> estimating diameter for each image')
            else:
                logger.info(
                    '>>>> using user-specified model, no auto-diameter estimation available'
                )
                diameter = model.diam_mean
        else:
            diameter = args.diameter
            logger.info('>>>> using diameter %0.2f for all images' % diameter)

        # whether any mask/outline/flow file beyond the _seg.npy was requested
        # (npy output is controlled separately by --no_npy)
        saving_something = (args.save_png or args.save_tif or args.save_flows
                            or args.save_ncolor or args.save_txt
                            or args.save_outlines)

        tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)

        for image_name in tqdm(image_names, file=tqdm_out):
            image = io.imread(image_name)
            out = model.eval(image,
                             channels=channels,
                             diameter=diameter,
                             do_3D=args.do_3D,
                             net_avg=(not args.fast_mode
                                      and not args.no_net_avg),
                             augment=False,
                             resample=(not args.no_resample
                                       and not args.fast_mode),
                             flow_threshold=args.flow_threshold,
                             mask_threshold=args.mask_threshold,
                             diam_threshold=args.diam_threshold,
                             stitch_threshold=args.stitch_threshold,
                             invert=args.invert,
                             batch_size=args.batch_size,
                             interp=(not args.no_interp),
                             cluster=args.cluster,
                             channel_axis=args.channel_axis,
                             z_axis=args.z_axis,
                             omni=args.omni,
                             anisotropy=args.anisotropy,
                             verbose=args.verbose,
                             model_loaded=True)
            masks, flows = out[:2]
            # models.Cellpose also returns estimated diameters as the last item
            if len(out) > 3:
                diams = out[-1]
            else:
                diams = diameter
            if args.exclude_on_edges:
                masks = utils.remove_edge_masks(masks)
            if not args.no_npy:
                io.masks_flows_to_seg(image, masks, flows, diams, image_name,
                                      channels)
            if saving_something:
                io.save_masks(image,
                              masks,
                              flows,
                              image_name,
                              png=args.save_png,
                              tif=args.save_tif,
                              save_flows=args.save_flows,
                              save_outlines=args.save_outlines,
                              save_ncolor=args.save_ncolor,
                              dir_above=args.dir_above,
                              savedir=args.savedir,
                              save_txt=args.save_txt,
                              in_folders=args.in_folders)
        logger.info('>>>> completed in %0.3f sec' % (time.time() - tic))
    else:
        # -------------------------------------------------------------- training
        if builtin_model:
            if args.mxnet and args.pretrained_model == 'cyto2':
                logger.warning(
                    'cyto2 model not available in mxnet, using cyto model')
                args.pretrained_model = 'cyto'
            cpmodel_path = models.model_path(args.pretrained_model, 0,
                                             not args.mxnet)
            if cytoplasmic:
                szmean = 30.
            elif nuclear:
                szmean = 17.
            elif bacterial:
                szmean = 0.  # bacterial models are not rescaled
        else:
            cpmodel_path = os.fspath(args.pretrained_model)
            szmean = 30.

        test_dir = None if len(args.test_dir) == 0 else args.test_dir
        output = io.load_train_test_data(args.dir, test_dir, imf,
                                         args.mask_filter, args.unet,
                                         args.look_one_level_down)
        images, labels, image_names, test_images, test_labels, image_names_test = output

        # training with all channels
        if args.all_channels:
            img = images[0]
            if img.ndim == 3:
                # channel axis position unknown; assume it is the smallest
                nchan = min(img.shape)
            elif img.ndim == 2:
                nchan = 1
            else:
                raise ValueError(
                    'cannot infer number of channels for --all_channels from image with %d dimensions'
                    % img.ndim)
            channels = None
        else:
            nchan = 2

        # model path
        if not os.path.exists(cpmodel_path):
            if not args.train:
                error_message = 'ERROR: model path missing or incorrect - cannot train size model'
                logger.critical(error_message)
                raise ValueError(error_message)
            cpmodel_path = False
            logger.info('>>>> training from scratch')
            if args.diameter == 0:
                rescale = False
                logger.info(
                    '>>>> median diameter set to 0 => no rescaling during training'
                )
            else:
                rescale = True
                szmean = args.diameter
        else:
            # a mean diameter of 0 (bacterial models) means "do not rescale";
            # rescaling to a 0-pixel diameter would divide by zero
            rescale = szmean > 0
            args.diameter = szmean
            logger.info('>>>> pretrained model %s is being used' %
                        cpmodel_path)
            args.residual_on = 1
            args.style_on = 1
            args.concatenation = 0
        if rescale and args.train:
            logger.info(
                '>>>> during training rescaling images to fixed diameter of %0.1f pixels'
                % args.diameter)

        # omni model needs 4 classes. This must happen before the model is
        # built (it was previously unreachable inside the evaluation branch).
        # Users still have to delete stale flow fields manually.
        if args.omni and args.train:
            logger.info('>>>> Training omni model. Setting nclasses to 4.')
            logger.info(
                '>>>> Make sure your flow fields are deleted and re-computed.')
            args.nclasses = 4

        # initialize model
        if args.unet:
            model = core.UnetModel(device=device,
                                   pretrained_model=cpmodel_path,
                                   diam_mean=szmean,
                                   residual_on=args.residual_on,
                                   style_on=args.style_on,
                                   concatenation=args.concatenation,
                                   nclasses=args.nclasses,
                                   nchan=nchan)
        else:
            model = models.CellposeModel(device=device,
                                         torch=(not args.mxnet),
                                         pretrained_model=cpmodel_path,
                                         diam_mean=szmean,
                                         residual_on=args.residual_on,
                                         style_on=args.style_on,
                                         concatenation=args.concatenation,
                                         nclasses=args.nclasses,
                                         nchan=nchan,
                                         omni=args.omni)

        # train segmentation model
        if args.train:
            cpmodel_path = model.train(
                images,
                labels,
                train_files=image_names,
                test_data=test_images,
                test_labels=test_labels,
                test_files=image_names_test,
                learning_rate=args.learning_rate,
                channels=channels,
                save_path=os.path.realpath(args.dir),
                save_every=args.save_every,
                save_each=args.save_each,
                rescale=rescale,
                n_epochs=args.n_epochs,
                batch_size=args.batch_size,
                min_train_masks=args.min_train_masks,
                omni=args.omni)
            model.pretrained_model = cpmodel_path
            logger.info('>>>> model trained and saved to %s' % cpmodel_path)

        # train size model
        if args.train_size:
            sz_model = models.SizeModel(cp_model=model, device=device)
            sz_model.train(images,
                           labels,
                           test_images,
                           test_labels,
                           channels=channels,
                           batch_size=args.batch_size)
            if test_images is not None:
                predicted_diams, diams_style = sz_model.eval(
                    test_images, channels=channels)
                # labels may carry flows stacked after the mask; keep the mask
                if test_labels[0].ndim > 2:
                    tlabels = [lbl[0] for lbl in test_labels]
                else:
                    tlabels = test_labels
                true_diams = np.array(
                    [utils.diameters(lbl)[0] for lbl in tlabels])
                ccs = np.corrcoef(diams_style, true_diams)[0, 1]
                cc = np.corrcoef(predicted_diams, true_diams)[0, 1]
                logger.info(
                    'style test correlation: %0.4f; final test correlation: %0.4f'
                    % (ccs, cc))
                np.save(
                    os.path.join(
                        args.test_dir, '%s_predicted_diams.npy' %
                        os.path.split(cpmodel_path)[1]), {
                            'predicted_diams': predicted_diams,
                            'diams_style': diams_style
                        })
Ejemplo n.º 13
0
def test_cellpose_imports_without_error():
    """Smoke test: cellpose imports and its default models can be constructed.

    Any exception raised while importing or constructing fails the test.
    """
    import cellpose  # importing the package is itself part of the check
    from cellpose import models, core

    # Construct each default model once; the instances are not inspected.
    models.CellposeModel()
    core.UnetModel()
Ejemplo n.º 14
0
def _load_fold(root, nimg):
    """Load images and flow-augmented labels from one fold directory.

    For each index ``i`` in ``range(nimg)`` this reads ``%03d_img.tif``,
    ``%03d_masks.tif`` and ``%03d_img_flows.tif`` from ``root``. Each label is
    that image's mask (given a new leading axis) stacked on top of that same
    image's flows along axis 0.

    Returns:
        (images, labels): two lists of length ``nimg``.
    """
    images = [
        io.imread(os.path.join(root, '%03d_img.tif' % i)) for i in range(nimg)
    ]
    labels = []
    for i in range(nimg):
        masks = io.imread(os.path.join(root, '%03d_masks.tif' % i))
        flows = io.imread(os.path.join(root, '%03d_img_flows.tif' % i))
        # Pair each mask with its OWN flow file, not the list of all flows.
        labels.append(
            np.concatenate((masks[np.newaxis, :, :], flows), axis=0))
    return images, labels


def train_cellpose_nets(data_root):
    """ train networks on 9-folds of data (180 networks total) ... ~1 week on one GPU

    For each fold ``j`` in 0..8 this loads ``train{j}/`` (540 images) and
    ``test{j}/`` (68 images) under ``data_root``. It then trains 4
    ``CellposeModel`` networks for each of the 5 architecture configurations
    below (9 * 5 * 4 = 180), saving them under the fold's train directory.
    For the first network of the first configuration it also trains a
    ``SizeModel``, prints the style/final diameter correlations on the test
    set, and saves the predictions to ``test{j}/predicted_diams.npy``.
    """
    # can also run on command line for GPU cluster
    # python -m cellpose --train --use_gpu --dir images_cyto/train"$7"/ --test_dir images_cyto/test"$7"/ --img_filter _img --pretrained_model None --chan 2 --chan2 1 --unet "$1" --nclasses "$2" --learning_rate "$3" --residual_on "$4" --style_on "$5" --concatenation "$6"
    device = mx.gpu()
    ntest = 68
    ntrain = 540
    # Architecture configurations; index k selects one column across the lists.
    concatenation = [0, 0, 0, 1, 1]
    residual_on = [1, 1, 0, 1, 0]
    style_on = [1, 0, 1, 1, 0]
    channels = [2, 1]
    nets_per_config = 4

    for j in range(9):
        # load images and labels (masks stacked with their flows)
        train_root = os.path.join(data_root, 'train%d/' % j)
        test_root = os.path.join(data_root, 'test%d/' % j)
        train_data, train_labels = _load_fold(train_root, ntrain)
        test_data, test_labels = _load_fold(test_root, ntest)

        # train networks
        for k in range(len(concatenation)):
            for inet in range(nets_per_config):
                model = models.CellposeModel(device=device,
                                             pretrained_model=None,
                                             diam_mean=30,
                                             residual_on=residual_on[k],
                                             style_on=style_on[k],
                                             concatenation=concatenation[k])
                model.train(train_data,
                            train_labels,
                            test_data=test_data,
                            test_labels=test_labels,
                            channels=channels,
                            rescale=True,
                            save_path=train_root)

                # train size network on default network once
                if k == 0 and inet == 0:
                    sz_model = models.SizeModel(model, device=device)
                    sz_model.train(train_data,
                                   train_labels,
                                   test_data,
                                   test_labels,
                                   channels=channels)

                    predicted_diams, diams_style = sz_model.eval(
                        test_data, channels=channels)
                    # Ground-truth diameters come from the mask plane
                    # (index 0) of each stacked label; compute them once.
                    tlabels = [lbl[0] for lbl in test_labels]
                    true_diams = np.array(
                        [utils.diameters(lbl)[0] for lbl in tlabels])
                    ccs = np.corrcoef(diams_style, true_diams)[0, 1]
                    cc = np.corrcoef(predicted_diams, true_diams)[0, 1]
                    print(
                        'style test correlation: %0.4f; final test correlation: %0.4f'
                        % (ccs, cc))
                    np.save(
                        os.path.join(test_root, 'predicted_diams.npy'), {
                            'predicted_diams': predicted_diams,
                            'diams_style': diams_style
                        })
Ejemplo n.º 15
0
            # Load the label image paired with each training image path.
            labels = [skimage.io.imread(label_names[n]) for n in range(nimg)]
            # Fall back to training from scratch when the requested pretrained
            # model path is not present on disk.
            # NOTE(review): os.path.exists(None) raises TypeError — confirm
            # cpmodel_path is always a str when this line is reached.
            if not os.path.exists(cpmodel_path):
                cpmodel_path = None
                print('>>>> training from scratch')
            else:
                print('>>>> training starting with pretrained_model %s' %
                      cpmodel_path)

            # Optional held-out set: only loaded when a test directory is given;
            # otherwise None is passed through to model.train.
            test_images, test_labels = None, None
            if len(args.test_dir) > 0:
                image_names_test = get_image_files(args.test_dir)
                label_names_test = get_label_files(image_names_test, imf,
                                                   args.mask_filter)
                # NOTE(review): this overwrites `nimg` (the count used to load
                # `labels` above) with the test image count — confirm no later
                # code expects the training count.
                nimg = len(image_names_test)
                test_images = [
                    skimage.io.imread(image_names_test[n]) for n in range(nimg)
                ]
                test_labels = [
                    skimage.io.imread(label_names_test[n]) for n in range(nimg)
                ]
            # args.unet indexes the name list: 0/False -> 'cellpose', 1/True -> 'unet'.
            print('>>>> %s model' % (['cellpose', 'unet'][args.unet]))
            model = models.CellposeModel(device=device,
                                         unet=args.unet,
                                         pretrained_model=cpmodel_path)
            # Trained model is saved under the resolved (realpath) args.dir.
            model.train(images,
                        labels,
                        test_images,
                        test_labels,
                        learning_rate=args.learning_rate,
                        channels=channels,
                        save_path=os.path.realpath(args.dir))
Ejemplo n.º 16
0
# Read the list of image paths (one per line) and load every image.
# The `with` block guarantees the list file is closed even on error.
with open(in_file_dir, "r") as input_file:
    # Skip blank lines so a stray empty line does not become imread('').
    files = [line for line in input_file.read().splitlines() if line.strip()]
imgs = [skimage.io.imread(f) for f in files]
nimg = len(imgs)

# NOTE(review): `imgs_2D` is an alias of `imgs`, not a copy, so the RGB
# conversion below is visible through it as well and the model.eval calls
# receive the converted images. Confirm this is intended; use `list(imgs)`
# here if the original (possibly grayscale) images must be preserved.
imgs_2D = imgs
# convert grayscale to rgb
for i, img in enumerate(imgs):
    if img.ndim < 3:
        imgs[i] = plot.image_to_rgb(img)

# Alternative: built-in model via models.Cellpose(gpu=False, model_type='cyto').
# Here a custom pretrained model stored in the "www" folder next to this
# script is loaded instead, running on the CPU.
model = models.CellposeModel(
    gpu=False,
    pretrained_model=os.path.join(
        os.path.dirname(getsourcefile(lambda: 0)), "www",
        "cellpose_residual_on_style_on_concatenation_off_Manually_curated_Images_and_Masks_2021_01_08_21_36_17.184931"
    ))

# cellpose channel convention: [channel to segment, nuclear channel], with
# 1=red, 2=green, 3=blue and 0=none.
channels_r = [1, 0]
channels_g = [2, 0]
channels_b = [3, 0]

# get masks for red, green, and blue channels
masks_r, flows, styles = model.eval(imgs_2D,
                                    diameter=None,
                                    flow_threshold=None,
                                    channels=channels_r)
masks_g, flows, styles = model.eval(imgs_2D,
                                    diameter=None,
                                    flow_threshold=None,