Example #1
0
def preload_images(specs, fov_id_list):
    '''Cache half-size first/last phase images for every channel of every FOV.

    For each FOV in ``fov_id_list`` and each peak listed in ``specs[fov_id]``
    the phase stack is loaded, and two frames are kept: the frame at index
    ``p['channel_picker']['first_image']`` and the frame at index
    ``p['channel_picker']['last_image']``. Both are resized to half the width
    and half the height of the first frame. The result is passed to the UI so
    that its figures can be populated much faster.

    Parameters
    ----------
    specs : dict
        ``specs[fov_id]`` is a dict whose keys are the peak ids to preload.
    fov_id_list : iterable
        FOV ids to preload.

    Returns
    -------
    dict
        ``{fov_id: {peak_id: {'first': ndarray, 'last': ndarray}}}``
    '''
    ui_cache = {}

    for fov_id in fov_id_list:
        mm3.information("Preloading images for FOV {}.".format(fov_id))
        fov_cache = ui_cache[fov_id] = {}

        for peak_id in specs[fov_id].keys():
            stack = mm3.load_stack(fov_id, peak_id, color=p['phase_plane'])
            entry = fov_cache[peak_id] = {'first': None, 'last': None}

            first_idx = p['channel_picker']['first_image']
            frame_shape = stack[first_idx, :, :].shape
            # PIL sizes are (width, height); numpy shapes are (rows, cols)
            half_size = (int(frame_shape[1] / 2), int(frame_shape[0] / 2))

            first_frame = Image.fromarray(stack[first_idx, :, :])
            entry['first'] = np.array(first_frame.resize(half_size))

            last_idx = p['channel_picker']['last_image']
            last_frame = Image.fromarray(stack[last_idx, :, :])
            entry['last'] = np.array(last_frame.resize(half_size))

    return ui_cache
Example #2
0
    # load specs file: specs[fov_id][peak_id] -> channel assignment value
    specs = mm3.load_specs()

    # make list of FOVs to process (keys of the specs dictionary)
    fov_id_list = sorted(specs.keys())

    # remove fovs if the user specified so
    if user_spec_fovs:
        fov_id_list[:] = [fov for fov in fov_id_list if fov in user_spec_fovs]

    # Take the image dimensions from the first analyzed (spec == 1) channel
    # of the first FOV.
    peaks_list = [
        peak_id for peak_id, val in specs[fov_id_list[0]].items() if val == 1
    ]
    img_height, img_width = mm3.load_stack(
        fov_id_list[0], peaks_list[0], color=p['phase_plane'])[0, :, :].shape

    # how many images total will we concatenate horizontally?
    # (one per analyzed channel across all selected FOVs)
    img_count = sum(
        1 for fov_id in fov_id_list
        for val in specs[fov_id].values() if val == 1)

    # placeholder array of img_height, and proper width to hold all pixels
    # from every analyzed channel
    phase_arr = np.zeros((img_height, img_width * img_count), 'uint16')

    if namespace.cell_segs:
        # p['seg_dir'] = 'segmented'
        seg_arr = np.zeros((img_height, img_width * img_count), 'uint16')
Example #3
0
def track_loop(fov_id,
               peak_id,
               params,
               tracks,
               model_dict,
               cell_number=6,
               data_number=9,
               img_file_name=None,
               seg_file_name=None,
               track_type='cells',
               max_cell_number=6):
    '''Run the tracking models on one trap and hand the results to the tracker.

    Loads the segmentation stack and an intensity stack (the phase plane for
    cells, the foci plane for foci). Frames whose intensity image has no
    signal are patched. Region properties are then extracted per frame, and
    every model in ``model_dict`` except the cell-count models is run on the
    generator output. The predictions are passed to ``run_cells`` or
    ``run_foci``.

    Parameters
    ----------
    fov_id, peak_id
        Identify the trap; used to load the stacks and in log messages.
    params : dict
        Reads 'seg_img' and 'phase_plane'. For foci tracking it also reads
        params['foci']['foci_plane'] and params['cell_dir'].
    tracks
        Passed through to run_cells / run_foci (presumably filled in there;
        not visible in this function).
    model_dict : dict
        Maps model name -> object with a ``predict`` method.
    cell_number : int
        First entry of the prediction generator's ``dim``. Per the original
        comments, this is the number of top cells per trap considered.
    data_number : int
        Last entry of the prediction generator's ``dim``.
    img_file_name, seg_file_name : str or None
        If img_file_name is given, both stacks are read from these files with
        io.imread instead of mm3.load_stack.
    track_type : {'cells', 'foci'}
        Which tracking mode to run.
    max_cell_number : int
        Forwarded to run_foci.

    Raises
    ------
    ValueError
        If track_type is neither 'cells' nor 'foci'. (Previously this
        surfaced as an UnboundLocalError further down.)
    '''
    if track_type not in ('cells', 'foci'):
        raise ValueError(
            "track_type must be 'cells' or 'foci', got {!r}".format(track_type))

    if img_file_name is None:
        # Both modes use the same segmentation plane; only the intensity
        # plane differs.
        seg_stack = mm3.load_stack(fov_id, peak_id, color=params['seg_img'])
        if track_type == 'cells':
            img_plane = params['phase_plane']
        else:
            img_plane = params['foci']['foci_plane']
        img_stack = mm3.load_stack(fov_id, peak_id, color=img_plane)
    else:
        seg_stack = io.imread(seg_file_name)
        img_stack = io.imread(img_file_name)

    frame_number = seg_stack.shape[0]

    # Sometimes an image is missed and has no signal. Flag such frames:
    # foci images whose max intensity is below 100, or phase images whose
    # mean intensity is below 200.
    if track_type == 'foci':
        no_signal_frames = [
            k for k, img in enumerate(img_stack) if np.max(img) < 100
        ]
    else:
        no_signal_frames = [
            k for k, img in enumerate(img_stack) if np.mean(img) < 200
        ]

    # Replace the segmentation of each flagged frame with the prior frame.
    # Processing in ascending order means a run of missed frames all inherit
    # the last good frame. Frame 0 has no prior frame (index -1 would wrap
    # around to the LAST frame), so it is left untouched.
    for k in no_signal_frames:
        if 0 < k < frame_number:
            seg_stack[k, ...] = seg_stack[k - 1, ...]

    if track_type == 'cells':
        regions_by_time = [
            measure.regionprops(label_image=img) for img in seg_stack
        ]
        # have generator yield info for the top cells in all frames
        prediction_generator = mm3.PredictTrackDataGenerator(
            regions_by_time,
            batch_size=frame_number,
            dim=(cell_number, 5, data_number),
            track_type=track_type)
    else:
        # use the params argument rather than the module-global `p`
        with open(params['cell_dir'] + '/all_cells.pkl', 'rb') as cell_file:
            all_cells = pickle.load(cell_file)
        regions_by_time = [
            mm3.sort_regions_in_list(
                measure.regionprops(label_image=img,
                                    intensity_image=img_stack[i, :, :]))
            for i, img in enumerate(seg_stack)
        ]
        prediction_generator = mm3.PredictTrackDataGenerator(
            regions_by_time,
            batch_size=frame_number,
            dim=(cell_number, 5, data_number),
            track_type=track_type,
            img_stack=img_stack,
            images=True,
            img_dim=(5, 256, 32))
    # batch_size == frame_number, so item 0 covers all frames (assumed;
    # confirm against PredictTrackDataGenerator)
    cell_info = prediction_generator[0]

    # the cell-count classifiers are skipped here
    skipped_models = {
        'zero_cell_model', 'one_cell_model', 'two_cell_model',
        'geq_three_cell_model'
    }
    predictions_dict = {}
    # run data through each classification model
    for key, mod in model_dict.items():
        if key in skipped_models:
            continue
        mm3.information(
            'Predicting probability of {} events in FOV {}, trap {}.'.format(
                '_'.join(key.split('_')[:-1]), fov_id, peak_id))
        predictions_dict['{}_predictions'.format(key)] = mod.predict(cell_info)

    if track_type == 'cells':
        run_cells(
            tracks,
            peak_id,
            fov_id,
            params,
            predictions_dict,
            regions_by_time,
        )
        return

    # The combined foci model returns six per-slot "outbound" arrays followed
    # by the appear predictions.
    (out1, out2, out3, out4, out5, out6,
     appear_predictions) = predictions_dict['all_model_predictions']
    outbounds = (out1, out2, out3, out4, out5, out6)

    pred_dict = {'appear_model_predictions': appear_predictions}

    # Take the -2nd element of each outbound array. The -1st is for
    # "no focus", the -2nd is for "disappear", and 0:6 are for migrate.
    disappear_predictions = np.transpose(
        np.array([outbound[:, -2] for outbound in outbounds]))
    pred_dict['disappear_model_predictions'] = disappear_predictions
    # The key was historically misspelled; keep it as an alias in case
    # run_foci still reads it.
    pred_dict['disappear_model_predicitons'] = disappear_predictions

    # take the 0:6 elements of each outbound prediction result
    pred_dict['migrate_model_predictions'] = np.concatenate(
        [outbound[:, :6] for outbound in outbounds], axis=1)

    run_foci(tracks,
             peak_id,
             fov_id,
             params,
             pred_dict,
             regions_by_time,
             all_cells,
             max_cell_number=max_cell_number,
             appear_threshold=0.85)
Example #4
0
def fov_plot_channels(fov_id,
                      crosscorrs,
                      specs,
                      outputdir='.',
                      phase_plane='c1'):
    '''
    Creates a plot with the channels with guesses for empties and full channels.
    The plot is saved in PDF format as ``fov_xy{fov_id:03d}.pdf`` in
    ``outputdir``.

    Each channel (peak) of the FOV gets one column with three rows:
      1. the first phase image,
      2. the last phase image, colored by its assignment in ``specs``,
      3. the cross correlation values over time.

    Parameters
    fov_id : int
        FOV number. It is formatted with ``%d`` / ``{:03d}``, so it must be
        an integer.
    crosscorrs : dictionary or None
        crosscorrs[fov_id][peak_id] holds 'ccs' (a cross correlation value
        per time point) and 'cc_avg'. If falsy, a flat placeholder is plotted.
    specs : dictionary
        specs[fov_id][peak_id] is the channel assignment: 1 = analyze
        (green), 0 = reference (blue), anything else = don't analyze (red).
    outputdir : str
        Directory the PDF is written to.
    phase_plane : str
        Color plane passed to mm3.load_stack.

    Returns
    specs : dictionary
        The input ``specs``, unchanged.
    '''
    mm3.information("Plotting channels for FOV %d." % fov_id)

    # plot the peaks peak by peak using sorted list
    sorted_peaks = sorted(specs[fov_id])

    # set up figure for user assisted choosing: one column per channel
    axw = 1
    axh = 4 * axw
    nrows = 3
    ncols = len(sorted_peaks)
    # num='none' makes pyplot reuse any open figure with that label, so the
    # figure is always closed in the finally block, even if plotting fails.
    fig = plt.figure(num='none',
                     facecolor='w',
                     figsize=(ncols * axw, nrows * axh))
    try:
        gs = gridspec.GridSpec(nrows, ncols, wspace=0.5, hspace=0.1, top=0.90)

        for n, peak_id in enumerate(sorted_peaks):
            # load data for figure
            image_data = mm3.load_stack(fov_id, peak_id, color=phase_plane)
            first_img = rescale_intensity(image_data[0, :, :])  # t=0
            last_img = rescale_intensity(image_data[-1, :, :])  # last frame

            axhi = fig.add_subplot(gs[0, n])
            axmid = fig.add_subplot(gs[1, n])
            axlo = fig.add_subplot(gs[2, n])

            # top row: first image in each channel
            axhi.imshow(first_img, cmap=plt.cm.gray, interpolation='nearest')
            axhi.axis('off')
            axhi.set_title(str(peak_id), fontsize=12)
            if n == 0:
                # NOTE: axis('off') also hides axis labels, so this is not drawn
                axhi.set_ylabel("first time point")

            # middle row: last time point, colored by channel assignment
            axmid.axis('off')
            spec = specs[fov_id][peak_id]
            if spec == 1:  # 1 means analyze, show green
                cmap = plt.cm.Greens_r
            elif spec == 0:  # 0 means reference, show blue
                cmap = plt.cm.Blues_r
            else:  # otherwise show red, means don't analyze
                cmap = plt.cm.Reds_r
            axmid.imshow(last_img, cmap=cmap, interpolation='nearest')
            if n == 0:
                axmid.set_ylabel("last time point")

            # bottom row: cross correlations across time
            if crosscorrs:  # don't try to plot if it's not there.
                peak_xc = crosscorrs[fov_id][peak_id]
                ccs = peak_xc['ccs']  # list of cc values
                axlo.plot(ccs, range(len(ccs)))
                axlo.set_title('avg=%1.2f' % peak_xc['cc_avg'], fontsize=8)
            else:
                axlo.plot(np.zeros(10), range(10))
            axlo.get_xaxis().set_ticks([0.8, 0.9, 1.0])
            axlo.set_xlim((0.8, 1))
            axlo.tick_params('x', labelsize=8)
            if n == 0:
                axlo.set_ylabel("time index, CC on X")
            else:
                axlo.set_yticks([])

        fig.suptitle("FOV {:d}".format(fov_id), fontsize=14)
        fileout = os.path.join(outputdir, 'fov_xy{:03d}.pdf'.format(fov_id))
        fig.savefig(fileout, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

    mm3.information("Written FOV {}'s channels in {}".format(fov_id, fileout))

    return specs
Example #5
0
def fov_CNN_plot_channels(fov_id,
                          predictionDict,
                          specs,
                          outputdir='.',
                          phase_plane='c1'):
    '''
    Creates a plot with the channels with guesses for empties and full channels.
    The plot is saved in PDF format as ``fov_xy{fov_id:03d}.pdf`` in
    ``outputdir``.

    Each channel (peak) of the FOV gets one column with three rows:
      1. the first phase image,
      2. the last phase image, colored by its assignment in ``specs``,
      3. a horizontal bar chart of the CNN prediction values.

    Parameters
    fov_id : int
        FOV number. It is formatted with ``%d`` / ``{:03d}``, so it must be
        an integer.
    predictionDict : dictionary or None
        predictionDict[fov_id][peak_id] holds the CNN prediction values for
        that channel, one bar per category. If falsy, a flat placeholder is
        plotted.
    specs : dictionary
        specs[fov_id][peak_id] is the channel assignment: 1 = analyze
        (green), 0 = reference (blue), anything else = don't analyze (red).
    outputdir : str
        Directory the PDF is written to.
    phase_plane : str
        Color plane passed to mm3.load_stack.

    Returns
    specs : dictionary
        The input ``specs``, unchanged.
    '''
    mm3.information("Plotting channels for FOV %d." % fov_id)

    # Category names for prediction indices 0..3. This order is taken from
    # the labels the original tick hack produced; confirm it against the
    # CNN's output order.
    category_labels = ["Good", "Empty", "Out-of-focus", "Defective"]

    # plot the peaks peak by peak using sorted list
    sorted_peaks = sorted(specs[fov_id])

    # set up figure for user assisted choosing: one column per channel
    axw = 1
    axh = 4 * axw
    nrows = 3
    ncols = len(sorted_peaks)
    # num='none' makes pyplot reuse any open figure with that label, so the
    # figure is always closed in the finally block, even if plotting fails.
    fig = plt.figure(num='none',
                     facecolor='w',
                     figsize=(ncols * axw, nrows * axh))
    try:
        gs = gridspec.GridSpec(nrows, ncols, wspace=0.5, hspace=0.1, top=0.90)

        for n, peak_id in enumerate(sorted_peaks):
            # load data for figure
            image_data = mm3.load_stack(fov_id, peak_id, color=phase_plane)
            first_img = rescale_intensity(image_data[0, :, :])  # t=0
            last_img = rescale_intensity(image_data[-1, :, :])  # last frame

            axhi = fig.add_subplot(gs[0, n])
            axmid = fig.add_subplot(gs[1, n])
            axlo = fig.add_subplot(gs[2, n])

            # top row: first image in each channel
            axhi.imshow(first_img, cmap=plt.cm.gray, interpolation='nearest')
            axhi.axis('off')
            axhi.set_title(str(peak_id), fontsize=12)
            if n == 0:
                # NOTE: axis('off') also hides axis labels, so this is not drawn
                axhi.set_ylabel("first time point")

            # middle row: last time point, colored by channel assignment
            axmid.axis('off')
            spec = specs[fov_id][peak_id]
            if spec == 1:  # 1 means analyze, show green
                cmap = plt.cm.Greens_r
            elif spec == 0:  # 0 means reference, show blue
                cmap = plt.cm.Blues_r
            else:  # otherwise show red, means don't analyze
                cmap = plt.cm.Reds_r
            axmid.imshow(last_img, cmap=cmap, interpolation='nearest')
            if n == 0:
                axmid.set_ylabel("last time point")

            # bottom row: prediction values as horizontal bar chart
            if predictionDict:
                predictions = predictionDict[fov_id][peak_id]
                axlo.barh(range(len(predictions)), predictions)
                axlo.set_title('p', fontsize=8)
            else:
                axlo.plot(np.zeros(10), range(10))
            axlo.set_xlim((0, 1))
            if n == 0:
                # Pin one tick per bar so each label sits beside its bar.
                # Labelling auto-placed ticks puts labels at arbitrary positions.
                axlo.set_yticks(range(len(category_labels)))
                axlo.set_yticklabels(category_labels)
                axlo.set_ylabel("CNN prediction category")
            else:
                axlo.get_yaxis().set_ticks([])

        fig.suptitle("FOV {:d}".format(fov_id), fontsize=14)
        fileout = os.path.join(outputdir, 'fov_xy{:03d}.pdf'.format(fov_id))
        fig.savefig(fileout, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

    mm3.information("Written FOV {}'s channels in {}".format(fov_id, fileout))

    return specs
Example #6
0
    # Judging by its name, this loads the experiment's time table; its effect
    # is not visible in this excerpt.
    mm3.load_time_table()

    # This dictionary holds information for all cells
    # Cells = {}

    # do lineage creation per fov, per trap
    # NOTE(review): tracks is presumably filled after this excerpt ends;
    # nothing visible here writes to it.
    tracks = {}
    for i,fov_id in enumerate(fov_id_list):
        # tracks[fov_id] = {}
        # NOTE(review): the following comment mentions make_lineages_function
        # and a Cells dict, neither of which appears in this excerpt.
        # update will add the output from make_lineages_function, which is a
        # dict of Cell entries, into Cells
        # only channels whose spec value is 1 are tracked
        ana_peak_ids = [peak_id for peak_id in specs[fov_id].keys() if specs[fov_id][peak_id] == 1]
        # ana_peak_ids = [9,13,15,19,25,33,36,37,38,39] # was used for debugging
        for j,peak_id in enumerate(ana_peak_ids):

            # load the segmentation stack for this trap (color=p['seg_img'])
            seg_stack = mm3.load_stack(fov_id, peak_id, color=p['seg_img'])
            # run predictions for each tracking class
            # consider only the top six cells for a given trap when doing tracking
            cell_number = 6
            # number of time frames = length of the stack's first axis
            frame_number = seg_stack.shape[0]
            # get region properties (one regionprops list per frame)
            regions_by_time = [measure.regionprops(label_image=img) for img in seg_stack]

            # have generator yield info for top six cells in all frames
            # batch_size=frame_number, so item 0 presumably covers every frame
            # (confirm against PredictTrackDataGenerator)
            prediction_generator = mm3.PredictTrackDataGenerator(regions_by_time, batch_size=frame_number, dim=(cell_number,5,9))
            cell_info = prediction_generator.__getitem__(0)

            predictions_dict = {}
            # run data through each classification model
            for key,mod in model_dict.items():
Example #7
0
def track_loop(fov_id,
               peak_id,
               params,
               tracks,
               model_dict,
               cell_number=6,
               phase_file_name=None,
               seg_file_name=None):
    '''Predict tracking events for one trap and add its lineages to ``tracks``.

    Loads the segmentation and phase stacks, then patches segmentation frames
    whose phase image has no signal. It extracts region properties per frame
    and runs every model in ``model_dict`` except the cell-count models. The
    resulting track graph is converted to lineages, which are merged into
    ``tracks``.

    Parameters
    ----------
    fov_id, peak_id
        Identify the trap; used to load the stacks, in log messages and to
        build the track graph.
    params : dict
        Reads 'seg_img', 'phase_plane' and 'experiment_name'.
    tracks : dict
        Updated in place with the output of mm3.create_lineages_from_graph.
    model_dict : dict
        Maps model name -> object with a ``predict`` method.
    cell_number : int
        First entry of the prediction generator's ``dim``. Per the original
        comments, this is the number of top cells per trap considered.
    phase_file_name, seg_file_name : str or None
        If phase_file_name is given, both stacks are read from these files
        with io.imread instead of mm3.load_stack.
    '''
    if phase_file_name is None:
        seg_stack = mm3.load_stack(fov_id, peak_id, color=params['seg_img'])
        phase_stack = mm3.load_stack(fov_id,
                                     peak_id,
                                     color=params['phase_plane'])
    else:
        seg_stack = io.imread(seg_file_name)
        phase_stack = io.imread(phase_file_name)

    frame_number = seg_stack.shape[0]

    # Sometimes a phase contrast image is missed and has no signal: flag
    # frames whose mean phase intensity is below 200.
    no_signal_frames = [
        k for k, img in enumerate(phase_stack) if np.mean(img) < 200
    ]

    # Replace the segmentation of each flagged frame with the prior frame.
    # Processing in ascending order means a run of missed frames all inherit
    # the last good frame. Frame 0 has no prior frame (index -1 would wrap
    # around to the LAST frame), so it is left untouched.
    for k in no_signal_frames:
        if 0 < k < frame_number:
            seg_stack[k, ...] = seg_stack[k - 1, ...]

    regions_by_time = [
        measure.regionprops(label_image=img) for img in seg_stack
    ]

    # have generator yield info for the top cells in all frames
    prediction_generator = mm3.PredictTrackDataGenerator(
        regions_by_time, batch_size=frame_number, dim=(cell_number, 5, 9))
    # batch_size == frame_number, so item 0 covers all frames (assumed;
    # confirm against PredictTrackDataGenerator)
    cell_info = prediction_generator[0]

    # the cell-count classifiers are skipped here
    skipped_models = {
        'zero_cell_model', 'one_cell_model', 'two_cell_model',
        'geq_three_cell_model'
    }
    predictions_dict = {}
    # run data through each classification model
    for key, mod in model_dict.items():
        if key in skipped_models:
            continue
        mm3.information(
            'Predicting probability of {} events in FOV {}, trap {}.'.format(
                '_'.join(key.split('_')[:-1]), fov_id, peak_id))
        predictions_dict['{}_predictions'.format(key)] = mod.predict(cell_info)

    G, graph_df = mm3.initialize_track_graph(
        peak_id=peak_id,
        fov_id=fov_id,
        experiment_name=params['experiment_name'],
        predictions_dict=predictions_dict,
        regions_by_time=regions_by_time,
        born_threshold=0.85,
        appear_threshold=0.85)

    tracks.update(mm3.create_lineages_from_graph(G, graph_df, fov_id, peak_id))