예제 #1
0
def test_load():
    """
    Check that raw frames load with the expected shape and first pixel.

    The bundled ``example_data_6frames`` dataset is loaded twice: once
    restricted to its first three frames and once in full.
    """
    dataset = {'date': '20190604', 'name': 'example_data_6frames'}

    # Raw and processed data share the folder that sits next to this file.
    here = os.path.split(__file__)[0]
    data_dir = os.path.join(here, 'test_raw_data')

    processor = FrameProcessor(raw_data_dir=data_dir,
                               processed_data_dir=data_dir,
                               dataset=dataset)

    # Partial load: only the first three frames are requested.
    subset = iio.load_raw_data(
        processor._data_path, sub_seq=range(3), print_freq=1)
    for axis, expected in enumerate((3, 791, 1607)):
        assert (subset.shape[axis] == expected)
    assert (subset[0, 0, 0] == 5309)

    # Full load: no sub-sequence given, all six frames are expected.
    full = iio.load_raw_data(processor._data_path, print_freq=1)
    for axis, expected in enumerate((6, 791, 1607)):
        assert (full.shape[axis] == expected)
    assert (full[0, 0, 0] == 5309)
예제 #2
0
    def select_crop_rois(self):
        """
        Interactively select the regions of interest (ROIs) used for cropping.

        A short preview of the raw data is loaded and its log-scaled mean
        image is handed to the ROI picker. The picker is re-run until the
        user accepts the result by pressing enter at the prompt. The chosen
        coordinates are then written, together with the video names, to
        ``self._roi_coords_file``, and a per-ROI summary figure is saved as
        ``roi_summary_indiv.pdf`` in the output folder.

        :return: None.
        """
        # Only the first five frames are needed to build the preview image.
        print('Loading: ', self._data_path)
        preview = iio.load_raw_data(
            self._data_path, sub_seq=range(5), print_freq=1)

        # The +1 keeps log10 finite where the mean is zero.
        mean_img = np.log10(np.mean(preview, 0)+1)
        loc_names = self.vid_names
        output_folder = self._out_path

        # Pick the prompt function that exists on the running interpreter.
        if (sys.version_info > (3, 0)):
            ask = input
        else:
            ask = raw_input

        # Repeat the selection until the user answers with an empty line.
        while True:
            slices = self.__get_roi(mean_img, self.vid_names, output_folder)
            answer = ask('Press enter if ROIs look good. Otherwise say no:')
            if not answer:
                break

        np.savez(self._roi_coords_file, coords=slices, loc_names=loc_names)

        # One panel per ROI, each showing the cropped part of the mean image.
        fig = plt.figure()
        for idx, slice_ in enumerate(slices):
            rows = slice_[0]
            cols = slice_[1]
            ax = fig.add_subplot(1, self.nroi, idx + 1)
            ax.imshow(mean_img[rows, cols])
            plt.title(self.vid_names[idx])
        plt.savefig(self._out_path + '/roi_summary_indiv.pdf')
예제 #3
0
    def atlas_align(self):
        """
        Interactively align each ROI image to the atlas outline.

        For every name in ``self.vid_names`` the user clicks two keypoints
        (anterior and posterior midline) on the image. The atlas outline is
        aligned to the image from those keypoints and an overlay is shown
        for confirmation; the clicking step repeats until the user answers
        'y'. The atlas keypoints are currently fixed values (manual atlas
        clicking is switched off inside the method).

        For each name the following files are written to
        ``<self._keypoints_file>/<name>/<name>_source_extraction``:
        keypoints.npz, keypoints.png and overlay.png.

        :return: None.
        :raises FileNotFoundError: if the ROI coordinates file is missing.
        """
        print("Atlas alignment.")
        # The prompt only pauses until the user presses enter; the answer
        # stored in `text` is not inspected.
        # NOTE(review): the Python 2 prompt lacks the space before 'Do not'.
        if (sys.version_info > (3, 0)):
            text = input(
                'Press enter to begin quick atlas alignment.' +
                ' Do not worry about precision, this can be refined later: ')
        else:
            text = raw_input(
                'Press enter to begin quick atlas alignment.' +
                'Do not worry about precision, this can be refined later: ')

        # NOTE(review): `coord_file` is never read below, so this acts only
        # as a check that ROI selection has been run.
        if os.path.isfile(self._roi_coords_file):
            coord_file = np.load(self._roi_coords_file)
        else:
            raise FileNotFoundError(
                'ROI coordinates file has not been generated.')

        for name in self.vid_names:
            plt.close('all')
            roi_dir = os.path.join(self._out_path, str(name))
            # NOTE(review): `roi_path` is not used in this method.
            roi_path = os.path.join(roi_dir, str(name) + '.tif')
            # Only the first frame is requested (sub_seq=range(1)).
            # NOTE(review): `img` is whatever load_raw_data returns for a
            # one-frame request; confirm imshow accepts its shape.
            vid = iio.load_raw_data(roi_dir, sub_seq=range(1), print_freq=100)
            img = vid

            # One click is requested per entry, in this order.
            keypoint_positions = ['Anterior midline',
                                  'Posterior midline']

            # Load atlas.
            atlas, annotations, atlas_outline = reg.load_atlas()
            # Hard-coded switch: with False, the fixed atlas keypoints in the
            # else-branch are used instead of asking the user to click.
            do_manual_atlas_keypoint = False
            if do_manual_atlas_keypoint:
                atlas_coords = []
                plt.figure(figsize=(30, 30))
                plt.imshow(atlas_outline)
                for t in keypoint_positions:
                    plt.title('Click on ' + t)
                    c = plt.ginput(n=1)
                    plt.plot(c[0][0], c[0][1], 'ro')
                    atlas_coords.extend(c)
            else:
                atlas_coords = [(98, 227),
                                (348, 227)]
                plt.figure(figsize=(30, 30))
                plt.imshow(atlas_outline)
                for cc in atlas_coords:
                    plt.plot(cc[0], cc[1], 'ro')

            # Convert selected keypoints to array.
            atlas_coords_array = np.zeros((len(atlas_coords), 2))
            for ci, c in enumerate(atlas_coords):
                atlas_coords_array[ci, 0] = np.round(c[0])
                atlas_coords_array[ci, 1] = np.round(c[1])

            # Repeat keypoint selection until the user accepts the overlay.
            while 1:
                # NOTE(review): this figure is opened in addition to the
                # larger one created in the branch below.
                plt.figure()
                plt.imshow(img)
                img_coords = []
                # Hard-coded switch: True means the user clicks the keypoints.
                do_manual_patch_keypoint = True
                if do_manual_patch_keypoint:
                    plt.figure(figsize=(30, 30))
                    plt.imshow(img)
                    for t in keypoint_positions:
                        plt.title('Click on ' + t)
                        c = plt.ginput(n=1)
                        plt.plot(c[0][0], c[0][1], 'ro')
                        img_coords.extend(c)
                else:
                    # These numbers are just for fast debugging.
                    img_coords = [(26, 297),
                                  (430, 314)]
                    plt.figure(figsize=(30, 30))
                    plt.imshow(img)
                    for cc in img_coords:
                        plt.plot(cc[0], cc[1], 'ro')

                plt.close('all')

                # Convert selected keypoints to array.
                img_coords_array = np.zeros((len(img_coords), 2))
                for ci, c in enumerate(img_coords):
                    img_coords_array[ci, 0] = np.round(c[0])
                    img_coords_array[ci, 1] = np.round(c[1])

                # Only the first two keypoints of each set are passed on.
                aao, ai, tf = reg.align_atlas_to_image(
                    atlas_outline, img, atlas_coords_array[0:2, :],
                    img_coords_array[0:2, :], do_debug=False)
                aligned_atlas_outline = aao
                aligned_img = ai
                # NOTE(review): `tform` is not used or saved below.
                tform = tf

                # Overlay atlas on image for checking that things look good.
                overlay = reg.overlay_atlas_outline(aligned_atlas_outline, img)
                plt.figure(figsize=(20, 20))
                plt.imshow(overlay)
                # NOTE(review): the joined title reads 'good,and' (no space).
                plt.title('Check that things look good,' +
                          'and close this window manually.')
                plt.show()

                if (sys.version_info > (3, 0)):
                    text = input('Look good? [y] or [n]')
                else:
                    text = raw_input('Look good? [y] or [n]')

                # Any answer other than exactly 'y' restarts the selection.
                print(text)
                if text == 'y':
                    break

            # Save out selections
            # NOTE(review): `self._keypoints_file` is used as a directory root
            # here and `keypoints_dir` is not created in this method; confirm
            # it exists before this point. `name` is joined without str().
            keypoints_dir = os.path.join(self._keypoints_file, name,
                                         str(name) + '_source_extraction')
            save_fname = os.path.join(keypoints_dir, 'keypoints.npz')
            print('Saving keypoints and aligned atlas to: ' + save_fname)
            # `coords` holds the rounded image keypoints, `atlas_coords` the
            # unrounded atlas keypoint list.
            np.savez(save_fname,
                     coords=img_coords_array,
                     atlas_coords=atlas_coords,
                     atlas=atlas,
                     img=aligned_img,
                     aligned_atlas_outline=aligned_atlas_outline)

            # Image with the accepted keypoints marked in red.
            plt.figure()
            plt.imshow(img)
            for ci, c in enumerate(img_coords):
                print(c)
                plt.plot(c[0], c[1], 'ro')
            plt.savefig(os.path.join(keypoints_dir, 'keypoints.png'))

            # Final atlas overlay from the accepted alignment.
            plt.figure()
            plt.imshow(overlay)
            plt.savefig(os.path.join(keypoints_dir, 'overlay.png'))
예제 #4
0
    def get_alignment_keypoints(self):
        """
        DEPRECATED - use atlas_align() instead.

        This method used to provide a GUI for clicking alignment keypoints.
        It now fails unconditionally so that callers are pointed to
        atlas_align(). The former implementation sat after the
        unconditional raise and could never run, so it has been removed.

        :return: Never returns.
        :raises NotImplementedError: always, naming atlas_align() as the
            replacement.
        """
        # Raising a plain string is itself a TypeError, which hid the
        # intended message; raise a real exception type instead.
        raise NotImplementedError(
            'Use atlas_align() instead of this get_alignment_keypoints().')
예제 #5
0
    def plot_motion(self):
        """
        Plot a metric of motion across the time series of the 'top' video.

        Four fixed 100x100 pixel sub-regions are cropped from the video and
        each one is passed to ``self.get_motion``. The per-frame shift
        magnitude sqrt(shiftx**2 + shifty**2) is plotted for every crop and
        averaged across crops.

        Files written to ``self._out_path``: shift_crop<N>.pdf for each
        crop, shift_average.pdf, shift_templates.png and shift_targets.png.
        The shift magnitudes are also saved to ``self._shifts_file`` under
        the key ``inds``.

        :return: None.
        """
        for name in ['top']:
            roi_dir = os.path.join(self._out_path, str(name))

            # Every frame is loaded (no sub-sequence restriction).
            vid = iio.load_raw_data(roi_dir, sub_seq=None, print_freq=100)

            # Crop out subregions.
            crop1 = vid[:, 200:300, 200:300]
            crop2 = vid[:, 200:300, 400:500]
            crop3 = vid[:, 400:500, 200:300]
            crop4 = vid[:, 400:500, 400:500]
            crops = (crop1, crop2, crop3, crop4)

            # One row per crop, one column per frame.
            shifts = np.zeros((len(crops), vid.shape[0]))

            for crop_ind, crop in enumerate(crops):
                print('Crop #{}'.format(crop_ind))
                shiftx, shifty, template, target = self.get_motion(crop)

                # Magnitude of the shift; its first column is the per-frame
                # trace that gets stored.
                shift = np.sqrt((shiftx**2)+(shifty**2))
                shifts[crop_ind, :] = shift[:, 0]

                plt.figure()
                plt.plot(shift)
                plt.ylabel('Pixel shift')
                plt.xlabel('Time [frames]')
                plt.title('Crop {}'.format(crop_ind))
                plt.savefig(
                    self._out_path + '/shift_crop{}.pdf'.format(crop_ind))

                # Figures 100 and 101 are reused on every pass so that all
                # crops end up side by side in one figure each.
                plt.figure(100, figsize=(len(crops)*5, 5))
                plt.subplot(1, len(crops), crop_ind+1)
                plt.imshow(template)
                plt.title(str(crop_ind))
                plt.suptitle('Templates')

                plt.figure(101, figsize=(len(crops)*5, 5))
                plt.subplot(1, len(crops), crop_ind+1)
                plt.imshow(target)
                plt.title(str(crop_ind))
                plt.suptitle('Targets')

        print('Done saving crops')
        plt.figure()
        plt.plot(shifts.T, alpha=0.5)
        plt.plot(np.mean(shifts, axis=0), 'k')
        plt.ylabel('Pixel shift')
        plt.xlabel('Time [frames]')
        plt.title('Average pixel shift')
        print('Plotting average shifts.')
        plt.savefig(self._out_path + '/shift_average.pdf')

        # Re-select the shared figures by number to save them.
        plt.figure(100)
        plt.savefig(self._out_path + '/shift_templates.png')

        plt.figure(101)
        plt.savefig(self._out_path + '/shift_targets.png')

        np.savez(self._shifts_file, inds=shifts)
예제 #6
0
    def crop_stack(self, do_remove_led_frames=False, n_frames=None,
                   do_motion_correct=True, LED_buffer=2, led_std=3):
        """
        Loads the crop rois that were saved out using select_crop_rois().
        Loads the raw video. Saves each ROI out to its own bigtiff file at
        <out_path>/<name>/<name>.tif.
        Note: to import a bigtiff into ImageJ. Use File->Import->BioFormats
        :param do_remove_led_frames: bool. If true, will find indices
                    of frames where LED turns on. When saving out video
                    these frames are replaced by a linear blend of the
                    frames LED_buffer before and after them.
                    Additionally, will save out a file that contains
                    the LED frame indices, and a plot (led_peaks.pdf) of
                    the LED trace before and after removal.
        :param n_frames: int or None. If given, at most this many frames
                    are written per ROI; None writes all frames.
        :param do_motion_correct: bool. If true, each cropped frame is
                    shifted by the motion estimated against the first
                    cropped frame before it is written.
        :param LED_buffer:
            number of frames before and after LED frame to remove.
        :param led_std:
            float how many standard deviations above baseline to use
                        as threshold for finding an led frame.

        :return: Nothing.
        :raises FileNotFoundError: if the ROI coordinates file is missing.
        """
        # NOTE(review): if 'coords' was stored as an object array, recent
        # numpy versions need allow_pickle=True here - confirm.
        if os.path.isfile(self._roi_coords_file):
            coord_file = np.load(self._roi_coords_file)
        else:
            raise FileNotFoundError(
                'ROI coordinates file has not been generated.')
        slices = coord_file['coords']
        # loc_names = coord_file['loc_names']
        loc_names = self.vid_names
        print(slices)

        # Load up the raw data from individual .tif files
        stack = iio.load_raw_data(
            self._data_path, sub_seq=None, print_freq=100)

        if do_remove_led_frames:
            print('Finding LED frames.')
            # Find indices of LED frames from the mean intensity of a fixed
            # image region (indices 10:100 on axis 1, 700:900 on axis 2).
            ledx = slice(10, 100)
            ledy = slice(700, 900)
            avg_trace = np.squeeze(
                np.mean(np.mean(stack[:, ledx, ledy], 2), 1))
            # Only look at a segment in the middle - in case light was
            # turned off at the beginning or end of acquisition.
            trace_segment = avg_trace[
                int(len(avg_trace)/4):int(len(avg_trace)/2)]
            threshed_avg = (avg_trace > (np.std(trace_segment)*led_std +
                            np.percentile(trace_segment, 10))).astype(int)
            led_peak_frames = peakutils.indexes(threshed_avg, min_dist=30)
            np.savez(self._led_frame_file, inds=led_peak_frames)

            # Replace each LED frame and its neighbours with a linear blend
            # of the frames nf before and nf after the LED frame.
            n_total = np.shape(stack)[0]
            nf = LED_buffer
            for frame in led_peak_frames:
                before = frame - nf
                after = frame + nf
                has_before = before >= 0
                has_after = after < n_total
                if not (has_before or has_after):
                    # No clean frame on either side to blend from.
                    print('Skipping LED frame with no clean neighbour: ',
                          frame)
                    continue
                # Near a boundary only one anchor exists; use it for both
                # ends instead of wrapping around or indexing past the end.
                if not has_before:
                    before = after
                if not has_after:
                    after = before

                for ff in np.arange(-nf+1, nf):
                    target = frame + ff
                    if target < 0 or target >= n_total:
                        continue
                    # Blend weight runs from 1 to 2*nf-1 across the window.
                    ind = float(ff + nf)
                    stack[target, :, :] = (
                        stack[before, :, :]*(nf*2.0-ind)/(nf*2.0)
                        + stack[after, :, :]*(ind)/(nf*2.0))
            avg_trace_new = np.squeeze(
                np.mean(np.mean(stack[:, ledx, ledy], 2), 1))

            # Plot average trace before and after removal.
            plt.figure()
            p1, = plt.plot(avg_trace)
            plt.plot(led_peak_frames, avg_trace[led_peak_frames], 'ro')
            p2, = plt.plot(avg_trace_new, 'g')
            plt.legend([p1, p2], ['orig', 'filtered'])
            plt.savefig(self._out_path + '/led_peaks.pdf')

        # Save out each ROI to a separate bigtiff file.
        stack_depth = np.shape(stack)[0]
        if n_frames is not None:
            stack_depth = min(n_frames, stack_depth)

        for name, slice_ in zip(loc_names, slices):
            print('Saving out ROI: ' + str(name))
            roi_dir = os.path.join(self._out_path, str(name))
            if not os.path.exists(roi_dir):
                os.makedirs(roi_dir)
            roi_path = os.path.join(roi_dir, str(name)+'.tif')
            print('Saving ' + roi_path)
            t0 = time.time()
            with tifffile.TiffWriter(roi_path, bigtiff=True) as tif:
                # Newer tifffile releases name the method `write`; older
                # ones only provide `save`. Use whichever is available.
                write_frame = getattr(tif, 'write', None) or tif.save
                # The first cropped frame is the motion-correction reference.
                template = stack[0, slice_[0], slice_[1]]
                for idx in range(stack_depth):
                    frame = stack[idx, slice_[0], slice_[1]]
                    if do_motion_correct:
                        shiftx, shifty = self.get_averaged_motion(
                            template, frame)
                        shifted_frame = self.shift_frame(frame, shiftx, shifty)
                    else:
                        shifted_frame = frame

                    write_frame(shifted_frame)

            print('Saved. Took ', time.time() - t0, ' seconds.')