Exemple #1
0
def _event_sending_fn(event_port, event_queue, debug=False):
    """
    Forward acquisition events from a queue to a push socket.

    Loops forever, blocking on ``event_queue``. A ``None`` item is treated as
    poison: an ``acquisition-end`` special event is sent and the function
    returns. The socket is closed on every exit path, including errors.

    Parameters
    ----------
    event_port :
        Port passed to ``Bridge._connect_push`` to open the outgoing socket.
    event_queue :
        Queue-like object with a blocking ``get``. Items are a single event,
        a list of events, or ``None`` to shut down.
    debug :
         (Default value = False) print each batch as it is received and sent.

    Returns
    -------
    None
    """
    bridge = Bridge(debug=debug)
    event_socket = bridge._connect_push(event_port)
    try:
        while True:
            events = event_queue.get(block=True)
            if debug:
                print("got event(s):", events)
            if events is None:
                # Poison, time to shut down
                event_socket.send({"events": [{"special": "acquisition-end"}]})
                return
            # Always send a list under "events"; wrap a lone event.
            if not isinstance(events, list):
                events = [events]
            event_socket.send({"events": events})
            if debug:
                print("sent events")
    finally:
        # Release the socket on normal shutdown and on errors alike.
        event_socket.close()
Exemple #2
0
def _acq_hook_startup_fn(pull_port, push_port, hook_connected_evt, event_queue,
                         hook_fn, debug):
    """
    Serve an acquisition hook: receive events, pass them to ``hook_fn``, send results back.

    An ``acquisition-end`` special message is answered with ``{}`` and ends
    the loop. Both sockets are closed on every exit path, including when the
    hook has an unsupported number of arguments.

    Parameters
    ----------
    pull_port :
        Port passed to ``Bridge._connect_push`` (used for the outgoing socket).
    push_port :
        Port passed to ``Bridge._connect_pull`` (used for the incoming socket).
    hook_connected_evt :
        Event that is ``set()`` once both sockets are connected.
    event_queue :
        Passed as the third argument to 3-argument hooks.
    hook_fn :
        Callable taking ``(event)`` or ``(event, bridge, event_queue)``.
    debug :
        Forwarded to ``Bridge``.

    Returns
    -------
    None
    """
    bridge = Bridge(debug=debug)

    push_socket = bridge._connect_push(pull_port)
    pull_socket = bridge._connect_pull(push_port)
    hook_connected_evt.set()

    try:
        # The hook's arity cannot change, so inspect it once rather than per message.
        n_params = len(signature(hook_fn).parameters)

        while True:
            event_msg = pull_socket.receive()

            if 'special' in event_msg and event_msg['special'] == 'acquisition-end':
                push_socket.send({})
                return

            if n_params not in (1, 3):
                raise Exception(
                    'Incorrect number of arguments for hook function. Must be 1 or 3'
                )
            try:
                if n_params == 1:
                    new_event_msg = hook_fn(event_msg)
                else:
                    new_event_msg = hook_fn(event_msg, bridge, event_queue)
            except Exception as e:
                # NOTE(review): no reply is pushed for this message; confirm the
                # receiving side tolerates a missing response.
                warnings.warn(
                    'exception in acquisition hook: {}'.format(e))
                continue

            push_socket.send(new_event_msg)
    finally:
        push_socket.close()
        pull_socket.close()
Exemple #3
0
def _acq_hook_startup_fn(pull_port, push_port, hook_connected_evt, event_queue, hook_fn, debug):
    """
    Serve an acquisition hook: receive events, pass them to ``hook_fn``, send results back.

    A message carrying an ``"events"`` key (a sequence) is unwrapped to its
    list before calling the hook, and a list returned by the hook is wrapped
    back into ``{"events": [...]}``. An ``acquisition-end`` special message is
    answered with ``{}`` and ends the loop. Both sockets are closed on every
    exit path.

    Parameters
    ----------
    pull_port :
        Port passed to ``Bridge._connect_push`` (used for the outgoing socket).
    push_port :
        Port passed to ``Bridge._connect_pull`` (used for the incoming socket).
    hook_connected_evt :
        Event that is ``set()`` once both sockets are connected.
    event_queue :
        Passed as the third argument to 3-argument hooks.
    hook_fn :
        Callable taking ``(event)`` or ``(event, bridge, event_queue)``.
    debug :
        Forwarded to ``Bridge``.

    Returns
    -------
    None
    """
    bridge = Bridge(debug=debug)

    push_socket = bridge._connect_push(pull_port)
    pull_socket = bridge._connect_pull(push_port)
    hook_connected_evt.set()

    try:
        # The hook's arity cannot change, so inspect it once rather than per message.
        n_params = len(signature(hook_fn).parameters)

        while True:
            event_msg = pull_socket.receive()

            if "special" in event_msg and event_msg["special"] == "acquisition-end":
                push_socket.send({})
                return

            if "events" in event_msg:
                event_msg = event_msg["events"]  # convert from sequence

            if n_params not in (1, 3):
                raise Exception("Incorrect number of arguments for hook function. Must be 1 or 3")
            try:
                if n_params == 1:
                    new_event_msg = hook_fn(event_msg)
                else:
                    new_event_msg = hook_fn(event_msg, bridge, event_queue)
            except Exception as e:
                # NOTE(review): no reply is pushed for this message; confirm the
                # receiving side tolerates a missing response.
                warnings.warn("exception in acquisition hook: {}".format(e))
                continue

            if isinstance(new_event_msg, list):
                # convert back to the expected format for a sequence
                new_event_msg = {"events": new_event_msg}
            push_socket.send(new_event_msg)
    finally:
        push_socket.close()
        pull_socket.close()
Exemple #4
0
def _processor_startup_fn(pull_port, push_port, sockets_connected_evt, process_fn, event_queue, debug):
    """
    Run an image processor between an incoming and an outgoing socket.

    Each received message is decoded into a 2-D image (``Height`` x ``Width``
    from its metadata) and passed to ``process_fn``. The result may be:

    - ``None``, which drops the image;
    - a ``(pixels, metadata)`` pair;
    - a list of such pairs.

    Returned pixels must keep the input dtype. A ``finished`` special message
    is forwarded and ends the loop. Both sockets are closed on every exit path.

    Parameters
    ----------
    pull_port :
        Port passed to ``Bridge._connect_push`` (used for the outgoing socket).
    push_port :
        Port passed to ``Bridge._connect_pull`` (used for the incoming socket).
    sockets_connected_evt :
        Event that is ``set()`` once both sockets are connected.
    process_fn :
        Callable taking ``(image, metadata)`` or ``(image, metadata, bridge, event_queue)``.
    event_queue :
        Passed as the fourth argument to 4-argument processors.
    debug :
        Forwarded to ``Bridge``; also enables a connection message.
    """
    bridge = Bridge(debug=debug)
    push_socket = bridge._connect_push(pull_port)
    pull_socket = bridge._connect_pull(push_port)
    if debug:
        print('image processing sockets connected')
    sockets_connected_evt.set()

    def process_and_sendoff(image_tags_tuple, input_dtype):
        """Validate one (pixels, metadata) result against ``input_dtype`` and send it."""
        if len(image_tags_tuple) != 2:
            raise Exception('If image is returned, it must be of the form (pixel, metadata)')
        out_pixels = image_tags_tuple[0]
        out_metadata = image_tags_tuple[1]
        if out_pixels.dtype != input_dtype:
            raise Exception('Processed image pixels must have same dtype as input image pixels, '
                            'but instead they were {} and {}'.format(out_pixels.dtype, input_dtype))

        processed_img = {'pixels': serialize_array(out_pixels), 'metadata': out_metadata}
        push_socket.send(processed_img)

    try:
        # The processor's arity cannot change, so inspect it once rather than per image.
        n_params = len(signature(process_fn).parameters)

        while True:
            message = None
            while message is None:
                message = pull_socket.receive(timeout=30)  # check for new message

            if 'special' in message and message['special'] == 'finished':
                push_socket.send(message)  # Continue propagating the finished signal
                return

            metadata = message['metadata']
            pixels = deserialize_array(message['pixels'])
            image = np.reshape(pixels, [metadata['Height'], metadata['Width']])

            if n_params not in (2, 4):
                raise Exception('Incorrect number of arguments for image processing function, must be 2 or 4')
            try:
                if n_params == 2:
                    processed = process_fn(image, metadata)
                else:
                    processed = process_fn(image, metadata, bridge, event_queue)
            except Exception as e:
                warnings.warn('exception in image processor: {}'.format(e))
                continue

            if processed is None:
                continue

            results = processed if isinstance(processed, list) else [processed]
            for result in results:
                process_and_sendoff(result, pixels.dtype)
    finally:
        push_socket.close()
        pull_socket.close()
Exemple #5
0
def _event_sending_fn(event_port, event_queue, debug=False):
    """
    Forward acquisition events from a queue to a push socket.

    Loops forever, blocking on ``event_queue``. A ``None`` item is treated as
    poison: an ``acquisition-end`` special event is sent and the function
    returns. The socket is closed on every exit path, including errors.

    :param event_port: port passed to ``Bridge._connect_push`` to open the outgoing socket
    :param event_queue: queue-like object whose items are a single event, a list of events, or ``None``
    :param debug: print each batch as it is received and sent
    """
    bridge = Bridge(debug=debug)
    event_socket = bridge._connect_push(event_port)
    try:
        while True:
            events = event_queue.get(block=True)
            if debug:
                print('got event(s):', events)
            if events is None:
                # Poison, time to shut down
                event_socket.send({'events': [{'special': 'acquisition-end'}]})
                return
            # Always send a list under 'events'; wrap a lone event.
            if not isinstance(events, list):
                events = [events]
            event_socket.send({'events': events})
            if debug:
                print('sent events')
    finally:
        # Release the socket on normal shutdown and on errors alike.
        event_socket.close()
Exemple #6
0
def _storage_monitor_fn(dataset,
                        storage_monitor_push_port,
                        connected_event,
                        callback_fn,
                        debug=False):
    """
    Receive storage notifications and register each new index entry with ``dataset``.

    Stops when a message containing ``"finished"`` arrives. The monitor
    socket is closed on every exit path, including when ``dataset`` or
    ``callback_fn`` raises.

    :param dataset: object whose ``_add_index_entry`` is called with each ``index_entry``
    :param storage_monitor_push_port: port passed to ``Bridge._connect_pull``
    :param connected_event: event that is ``set()`` once the socket is connected
    :param callback_fn: optional callable invoked with the value returned by ``_add_index_entry``
    :param debug: forwarded to ``Bridge``
    """
    bridge = Bridge(debug=debug)
    monitor_socket = bridge._connect_pull(storage_monitor_push_port)

    try:
        connected_event.set()

        while True:
            message = monitor_socket.receive()

            if "finished" in message:
                # Poison, time to shut down
                return

            index_entry = message["index_entry"]
            axes = dataset._add_index_entry(index_entry)

            if callback_fn is not None:
                callback_fn(axes)
    finally:
        monitor_socket.close()
Exemple #7
0
def _storage_monitor_fn(dataset,
                        storage_monitor_push_port,
                        connected_event,
                        callback_fn,
                        debug=False):
    """
    Receive storage notifications and register each new index entry with ``dataset``.

    Stops when a message containing ``"finished"`` arrives. The monitor
    socket is closed on every exit path (before the ``Bridge`` context
    exits), including when ``dataset`` or ``callback_fn`` raises.

    :param dataset: object whose ``_add_index_entry`` is called with each ``index_entry``
    :param storage_monitor_push_port: port passed to ``Bridge._connect_pull``
    :param connected_event: event that is ``set()`` once the socket is connected
    :param callback_fn: optional callable invoked with the value returned by ``_add_index_entry``
    :param debug: forwarded to ``Bridge``
    """
    # TODO: might need to add in support for doing this on a different port, if Acquisition/bridge is not on default port
    with Bridge(debug=debug) as bridge:
        monitor_socket = bridge._connect_pull(storage_monitor_push_port)

        try:
            connected_event.set()

            while True:
                message = monitor_socket.receive()

                if "finished" in message:
                    # Poison, time to shut down
                    return

                index_entry = message["index_entry"]
                axes = dataset._add_index_entry(index_entry)

                if callback_fn is not None:
                    callback_fn(axes)
        finally:
            monitor_socket.close()
Exemple #8
0
    def __init__(self,
                 directory=None,
                 name=None,
                 image_process_fn=None,
                 pre_hardware_hook_fn=None,
                 post_hardware_hook_fn=None,
                 tile_overlap=None,
                 magellan_acq_index=None,
                 process=True,
                 debug=False):
        """
        :param directory: saving directory for this acquisition. Required unless an image process function will be
            implemented that diverts images from saving
        :type directory: str
        :param name: Saving name for the acquisition. Required unless an image process function will be
            implemented that diverts images from saving
        :type name: str
        :param image_process_fn: image processing function that will be called on each image that gets acquired.
            Can either take two arguments (image, metadata) where image is a numpy array and metadata is a dict
            containing the corresponding image metadata. Or a 4 argument version is accepted, which accepts (image,
            metadata, bridge, queue), where bridge is an instance of the pycromanager.acquire.Bridge
            object for the purposes of interacting with arbitrary code on the Java side (such as the micro-manager
            core), and queue is a Queue object that holds upcoming acquisition events.
            NOTE(review): the accepted return values are defined by the image processor worker and are not
            visible here; confirm before documenting them.
        :param pre_hardware_hook_fn: hook function that will be run just before the hardware is updated before acquiring
            a new image. Accepts either one argument (the current acquisition event) or three arguments (current event,
            bridge, event Queue)
        :param post_hardware_hook_fn: hook function that will be run just after the hardware is updated, before
            acquiring a new image. Accepts either one argument (the current acquisition event) or three arguments
            (current event, bridge, event Queue)
        :param tile_overlap: If given, XY tiles will be laid out in a grid and multi-resolution saving will be
            activated. Argument can be a two element tuple describing the pixel overlaps between adjacent
            tiles. i.e. (pixel_overlap_x, pixel_overlap_y), or an integer to use the same overlap for both.
            For these features to work, the current hardware configuration must have a valid affine transform
            between camera coordinates and XY stage coordinates
        :type tile_overlap: tuple, int
        :param magellan_acq_index: run this acquisition using the settings specified at this position in the main
            GUI of micro-magellan (micro-manager plugin). This index starts at 0
        :type magellan_acq_index: int
        :param process: (Experimental) use multiprocessing instead of multithreading for acquisition hooks and image
            processors
        :type process: boolean
        :param debug: print debugging stuff
        :type debug: boolean
        """
        self.bridge = Bridge(debug=debug)
        self._debug = debug
        self._dataset = None

        if directory is not None:
            # Expand ~ in path
            directory = os.path.expanduser(directory)
            # If path is relative, retain knowledge of the current working directory
            directory = os.path.abspath(directory)

        if magellan_acq_index is not None:
            magellan_api = self.bridge.get_magellan()
            self._remote_acq = magellan_api.create_acquisition(
                magellan_acq_index)
            # Magellan acquisitions generate their own events; there is no Python event queue
            self._event_queue = None
        else:
            # Create thread safe queue for events so they can be passed from multiple processes
            self._event_queue = multiprocessing.Queue()
            core = self.bridge.get_core()
            acq_factory = self.bridge.construct_java_object(
                'org.micromanager.remote.RemoteAcquisitionFactory',
                args=[core])

            # TODO: could add hiding viewer as an option
            show_viewer = directory is not None and name is not None
            if tile_overlap is None:
                # argument placeholders, these won't actually be used
                x_overlap = 0
                y_overlap = 0
            elif isinstance(tile_overlap, tuple):
                x_overlap, y_overlap = tile_overlap
            else:
                x_overlap = tile_overlap
                y_overlap = tile_overlap

            self._remote_acq = acq_factory.create_acquisition(
                directory, name, show_viewer, tile_overlap is not None,
                x_overlap, y_overlap)

        if image_process_fn is not None:
            processor = self.bridge.construct_java_object(
                'org.micromanager.remote.RemoteImageProcessor')
            self._remote_acq.add_image_processor(processor)
            self._start_processor(processor,
                                  image_process_fn,
                                  self._event_queue,
                                  process=process)

        if pre_hardware_hook_fn is not None:
            # The hook is constructed with its acquisition, exactly like the
            # post-hardware hook below; add_hook only takes the hook and its position.
            hook = self.bridge.construct_java_object(
                'org.micromanager.remote.RemoteAcqHook',
                args=[self._remote_acq])
            self._start_hook(hook,
                             pre_hardware_hook_fn,
                             self._event_queue,
                             process=process)
            self._remote_acq.add_hook(hook,
                                      self._remote_acq.BEFORE_HARDWARE_HOOK)
        if post_hardware_hook_fn is not None:
            hook = self.bridge.construct_java_object(
                'org.micromanager.remote.RemoteAcqHook',
                args=[self._remote_acq])
            self._start_hook(hook,
                             post_hardware_hook_fn,
                             self._event_queue,
                             process=process)
            self._remote_acq.add_hook(hook,
                                      self._remote_acq.AFTER_HARDWARE_HOOK)

        self._remote_acq.start()

        if magellan_acq_index is None:
            self.event_port = self._remote_acq.get_event_port()

            # Separate process that forwards queued events to the remote acquisition
            self.event_process = multiprocessing.Process(
                target=_event_sending_fn,
                args=(self.event_port, self._event_queue, self._debug),
                name='Event sending')
            self.event_process.start()
Exemple #9
0
class Acquisition(object):
    """
    Python front end for a remote (Java-side) micro-manager acquisition.

    Usable as a context manager: leaving the ``with`` block enqueues the
    shutdown signal (for non-Magellan acquisitions) and waits for completion.
    """

    def __init__(self,
                 directory=None,
                 name=None,
                 image_process_fn=None,
                 pre_hardware_hook_fn=None,
                 post_hardware_hook_fn=None,
                 tile_overlap=None,
                 magellan_acq_index=None,
                 process=True,
                 debug=False):
        """
        :param directory: saving directory for this acquisition. Required unless an image process function will be
            implemented that diverts images from saving
        :type directory: str
        :param name: Saving name for the acquisition. Required unless an image process function will be
            implemented that diverts images from saving
        :type name: str
        :param image_process_fn: image processing function that will be called on each image that gets acquired.
            Can either take two arguments (image, metadata) where image is a numpy array and metadata is a dict
            containing the corresponding image metadata. Or a 4 argument version is accepted, which accepts (image,
            metadata, bridge, queue), where bridge is an instance of the pycromanager.acquire.Bridge
            object for the purposes of interacting with arbitrary code on the Java side (such as the micro-manager
            core), and queue is a Queue object that holds upcoming acquisition events. Its return value is handed
            to ``_processor_startup_fn``. NOTE(review): the accepted return forms (None, a (pixels, metadata)
            tuple, or a list of such tuples) should be confirmed against that worker.
        :param pre_hardware_hook_fn: hook function that will be run just before the hardware is updated before acquiring
            a new image. Accepts either one argument (the current acquisition event) or three arguments (current event,
            bridge, event Queue)
        :param post_hardware_hook_fn: hook function that will be run just after the hardware is updated, before
            acquiring a new image. Accepts either one argument (the current acquisition event) or three arguments
            (current event, bridge, event Queue)
        :param tile_overlap: If given, XY tiles will be laid out in a grid and multi-resolution saving will be
            activated. Argument can be a two element tuple describing the pixel overlaps between adjacent
            tiles. i.e. (pixel_overlap_x, pixel_overlap_y), or an integer to use the same overlap for both.
            For these features to work, the current hardware configuration must have a valid affine transform
            between camera coordinates and XY stage coordinates
        :type tile_overlap: tuple, int
        :param magellan_acq_index: run this acquisition using the settings specified at this position in the main
            GUI of micro-magellan (micro-manager plugin). This index starts at 0
        :type magellan_acq_index: int
        :param process: (Experimental) use multiprocessing instead of multithreading for acquisition hooks and image
            processors
        :type process: boolean
        :param debug: print debugging stuff
        :type debug: boolean
        """
        self.bridge = Bridge(debug=debug)
        self._debug = debug
        self._dataset = None

        if directory is not None:
            # Expand ~ in path
            directory = os.path.expanduser(directory)
            # If path is relative, retain knowledge of the current working directory
            directory = os.path.abspath(directory)

        if magellan_acq_index is not None:
            magellan_api = self.bridge.get_magellan()
            self._remote_acq = magellan_api.create_acquisition(
                magellan_acq_index)
            # Magellan acquisitions generate their own events; there is no Python event queue
            self._event_queue = None
        else:
            # Create thread safe queue for events so they can be passed from multiple processes
            self._event_queue = multiprocessing.Queue()
            core = self.bridge.get_core()
            acq_factory = self.bridge.construct_java_object(
                'org.micromanager.remote.RemoteAcquisitionFactory',
                args=[core])

            # TODO: could add hiding viewer as an option
            show_viewer = directory is not None and name is not None
            if tile_overlap is None:
                # argument placeholders, these won't actually be used
                x_overlap = 0
                y_overlap = 0
            elif isinstance(tile_overlap, tuple):
                x_overlap, y_overlap = tile_overlap
            else:
                x_overlap = tile_overlap
                y_overlap = tile_overlap

            self._remote_acq = acq_factory.create_acquisition(
                directory, name, show_viewer, tile_overlap is not None,
                x_overlap, y_overlap)

        if image_process_fn is not None:
            processor = self.bridge.construct_java_object(
                'org.micromanager.remote.RemoteImageProcessor')
            self._remote_acq.add_image_processor(processor)
            self._start_processor(processor,
                                  image_process_fn,
                                  self._event_queue,
                                  process=process)

        if pre_hardware_hook_fn is not None:
            # The hook is constructed with its acquisition, exactly like the
            # post-hardware hook below; add_hook only takes the hook and its position.
            hook = self.bridge.construct_java_object(
                'org.micromanager.remote.RemoteAcqHook',
                args=[self._remote_acq])
            self._start_hook(hook,
                             pre_hardware_hook_fn,
                             self._event_queue,
                             process=process)
            self._remote_acq.add_hook(hook,
                                      self._remote_acq.BEFORE_HARDWARE_HOOK)
        if post_hardware_hook_fn is not None:
            hook = self.bridge.construct_java_object(
                'org.micromanager.remote.RemoteAcqHook',
                args=[self._remote_acq])
            self._start_hook(hook,
                             post_hardware_hook_fn,
                             self._event_queue,
                             process=process)
            self._remote_acq.add_hook(hook,
                                      self._remote_acq.AFTER_HARDWARE_HOOK)

        self._remote_acq.start()

        if magellan_acq_index is None:
            self.event_port = self._remote_acq.get_event_port()

            # Separate process that forwards queued events to the remote acquisition
            self.event_process = multiprocessing.Process(
                target=_event_sending_fn,
                args=(self.event_port, self._event_queue, self._debug),
                name='Event sending')
            self.event_process.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._event_queue is not None:  # magellan acquisitions don't have this
            # this should shut down storage and viewer as appropriate
            self._event_queue.put(None)
        # now wait on it to finish
        self.await_completion()

    def get_dataset(self):
        """
        Return a :class:`~pycromanager.data.Dataset` object that has access to the underlying pixels

        The dataset is created on first call and cached afterwards.

        :return: :class:`~pycromanager.data.Dataset` corresponding to this acquisition
        """
        if self._dataset is None:
            self._dataset = Dataset(
                remote_storage=self._remote_acq.get_storage())
        return self._dataset

    def await_completion(self):
        """
        Wait for acquisition to finish and resources to be cleaned up

        Polls the remote acquisition every 0.1 s.
        """
        while not self._remote_acq.is_finished():
            time.sleep(0.1)

    def acquire(self, events):
        """
        Submit an event or a list of events for acquisition. Optimizations (i.e. taking advantage of
        hardware synchronization, where available), will take place across this list of events, but not
        over multiple calls of this method. A single event is a python dictionary with a specific structure

        NOTE(review): Magellan acquisitions have no event queue (``_event_queue`` is None), so calling
        this on one fails with AttributeError.

        :param events: single event (i.e. a dictionary) or a list of events
        """
        self._event_queue.put(events)

    @staticmethod
    def _concurrency_primitives(process):
        """
        Return ``(connected_event, worker_class)`` for the requested concurrency model.

        The event and the worker must come from the same model. A ``threading.Event`` cannot
        be shared with a ``multiprocessing.Process``: the child's ``set()`` would never reach the
        parent's ``wait()``.

        :param process: True for multiprocessing, False for threading
        """
        if process:
            return multiprocessing.Event(), multiprocessing.Process
        return threading.Event(), threading.Thread

    def _start_hook(self, remote_hook, remote_hook_fn, event_queue, process):
        """
        Launch a worker running ``_acq_hook_startup_fn`` for ``remote_hook`` and wait until its sockets connect.

        :param process: run the worker in a separate process (True) or a thread (False)
        """
        hook_connected_evt, worker_cls = self._concurrency_primitives(process)

        pull_port = remote_hook.get_pull_port()
        push_port = remote_hook.get_push_port()

        hook_thread = worker_cls(
            target=_acq_hook_startup_fn,
            name='AcquisitionHook',
            args=(pull_port, push_port, hook_connected_evt, event_queue,
                  remote_hook_fn, self._debug))
        hook_thread.start()

        hook_connected_evt.wait()  # wait for push/pull sockets to connect

    def _start_processor(self, processor, process_fn, event_queue, process):
        """
        Launch a worker running ``_processor_startup_fn`` for ``processor``.

        :param process: run the worker in a separate process (True) or a thread (False)
        """
        # this must start first
        processor.start_pull()

        sockets_connected_evt, worker_cls = self._concurrency_primitives(process)

        pull_port = processor.get_pull_port()
        push_port = processor.get_push_port()

        self.processor_thread = worker_cls(
            target=_processor_startup_fn,
            args=(pull_port, push_port, sockets_connected_evt, process_fn,
                  event_queue, self._debug),
            name='ImageProcessor')
        self.processor_thread.start()

        sockets_connected_evt.wait()  # wait for push/pull sockets to connect
        # Pushing to the processor only starts once the worker is connected
        processor.start_push()
Exemple #10
0
    def __init__(
        self,
        directory=None,
        name=None,
        image_process_fn=None,
        event_generation_hook_fn=None,
        pre_hardware_hook_fn=None,
        post_hardware_hook_fn=None,
        post_camera_hook_fn=None,
        show_display=True,
        tile_overlap=None,
        max_multi_res_index=None,
        magellan_acq_index=None,
        magellan_explore=False,
        process=False,
        debug=False,
    ):
        """
        Parameters
        ----------
        directory : str
            saving directory for this acquisition. Required unless an image process function will be
            implemented that diverts images from saving
        name : str
            Saving name for the acquisition. Required unless an image process function will be
            implemented that diverts images from saving
        image_process_fn : Callable
            image processing function that will be called on each image that gets acquired.
            Can either take two arguments (image, metadata) where image is a numpy array and metadata is a dict
            containing the corresponding iamge metadata. Or a 4 argument version is accepted, which accepts (image,
            metadata, bridge, queue), where bridge and queue are an instance of the pycromanager.acquire.Bridge
            object for the purposes of interacting with arbitrary code on the Java side (such as the micro-manager
            core), and queue is a Queue objects that holds upcomning acquisition events. Both version must either
            return
        event_generation_hook_fn : Callable
            hook function that will as soon as acquisition events are generated (before hardware sequencing optimization
            in the acquisition engine. This is useful if one wants to modify acquisition events that they didn't generate
            (e.g. those generated by a GUI application). Accepts either one argument (the current acquisition event)
            or three arguments (current event, bridge, event Queue)
        pre_hardware_hook_fn : Callable
            hook function that will be run just before the hardware is updated before acquiring
            a new image. In the case of hardware sequencing, it will be run just before a sequence of instructions are
            dispatched to the hardware. Accepts either one argument (the current acquisition event) or three arguments
            (current event, bridge, event Queue)
        post_hardware_hook_fn : Callable
            hook function that will be run just before the hardware is updated before acquiring
            a new image. In the case of hardware sequencing, it will be run just after a sequence of instructions are
            dispatched to the hardware, but before the camera sequence has been started. Accepts either one argument
            (the current acquisition event) or three arguments (current event, bridge, event Queue)
        post_camera_hook_fn : Callable
            hook function that will be run just after the camera has been triggered to snapImage or
            startSequence. A common use case for this hook is when one want to send TTL triggers to the camera from an
            external timing device that synchronizes with other hardware. Accepts either one argument (the current
            acquisition event) or three arguments (current event, bridge, event Queue)
        tile_overlap : int or tuple of int
            If given, XY tiles will be laid out in a grid and multi-resolution saving will be
            actived. Argument can be a two element tuple describing the pixel overlaps between adjacent
            tiles. i.e. (pixel_overlap_x, pixel_overlap_y), or an integer to use the same overlap for both.
            For these features to work, the current hardware configuration must have a valid affine transform
            between camera coordinates and XY stage coordinates
        max_multi_res_index : int
            Maximum index to downsample to in multi-res pyramid mode (which is only active if a value for
            "tile_overlap" is passed in, or if running a Micro-Magellan acquisition). 0 is no downsampling,
            1 is downsampled up to 2x, 2 is downsampled up to 4x, etc. If not provided, it will be dynamically
            calculated and updated from data
        show_display : bool
            show the image viewer window
        magellan_acq_index : int
            run this acquisition using the settings specified at this position in the main
            GUI of micro-magellan (micro-manager plugin). This index starts at 0
        magellan_explore : bool
            Run a Micro-magellan explore acquisition
        process : bool
            Use multiprocessing instead of multithreading for acquisition hooks and image
            processors. This can be used to speed up CPU-bounded processing by eliminating bottlenecks
            caused by Python's Global Interpreter Lock, but also creates complications on Windows-based
            systems
        debug : bool
            whether to print debug messages
        """
        self.bridge = Bridge(debug=debug)
        self._debug = debug
        self._dataset = None

        if directory is not None:
            # Expend ~ in path
            directory = os.path.expanduser(directory)
            # If path is relative, retain knowledge of the current working directory
            directory = os.path.abspath(directory)

        if magellan_acq_index is not None:
            magellan_api = self.bridge.get_magellan()
            self._remote_acq = magellan_api.create_acquisition(
                magellan_acq_index)
            self._event_queue = None
        elif magellan_explore:
            magellan_api = self.bridge.get_magellan()
            self._remote_acq = magellan_api.create_explore_acquisition()
            self._event_queue = None
        else:
            # Create thread safe queue for events so they can be passed from multiple processes
            self._event_queue = multiprocessing.Queue(
            ) if process else queue.Queue()
            core = self.bridge.get_core()
            acq_factory = self.bridge.construct_java_object(
                "org.micromanager.remote.RemoteAcquisitionFactory",
                args=[core])

            show_viewer = show_display and (directory is not None
                                            and name is not None)
            if tile_overlap is None:
                # argument placeholders, these wont actually be used
                x_overlap = 0
                y_overlap = 0
            else:
                if type(tile_overlap) is tuple:
                    x_overlap, y_overlap = tile_overlap
                else:
                    x_overlap = tile_overlap
                    y_overlap = tile_overlap

            self._remote_acq = acq_factory.create_acquisition(
                directory,
                name,
                show_viewer,
                tile_overlap is not None,
                x_overlap,
                y_overlap,
                max_multi_res_index if max_multi_res_index is not None else -1,
            )
        storage = self._remote_acq.get_data_sink()
        if storage is not None:
            self.disk_location = storage.get_disk_location()

        if image_process_fn is not None:
            processor = self.bridge.construct_java_object(
                "org.micromanager.remote.RemoteImageProcessor")
            self._remote_acq.add_image_processor(processor)
            self._start_processor(processor,
                                  image_process_fn,
                                  self._event_queue,
                                  process=process)

        if event_generation_hook_fn is not None:
            hook = self.bridge.construct_java_object(
                "org.micromanager.remote.RemoteAcqHook",
                args=[self._remote_acq])
            self._start_hook(hook,
                             event_generation_hook_fn,
                             self._event_queue,
                             process=process)
            self._remote_acq.add_hook(hook,
                                      self._remote_acq.EVENT_GENERATION_HOOK)
        if pre_hardware_hook_fn is not None:
            hook = self.bridge.construct_java_object(
                "org.micromanager.remote.RemoteAcqHook",
                args=[self._remote_acq])
            self._start_hook(hook,
                             pre_hardware_hook_fn,
                             self._event_queue,
                             process=process)
            self._remote_acq.add_hook(hook,
                                      self._remote_acq.BEFORE_HARDWARE_HOOK)
        if post_hardware_hook_fn is not None:
            hook = self.bridge.construct_java_object(
                "org.micromanager.remote.RemoteAcqHook",
                args=[self._remote_acq])
            self._start_hook(hook,
                             post_hardware_hook_fn,
                             self._event_queue,
                             process=process)
            self._remote_acq.add_hook(hook,
                                      self._remote_acq.AFTER_HARDWARE_HOOK)
        if post_camera_hook_fn is not None:
            hook = self.bridge.construct_java_object(
                "org.micromanager.remote.RemoteAcqHook",
                args=[self._remote_acq])
            self._start_hook(hook,
                             post_camera_hook_fn,
                             self._event_queue,
                             process=process)
            self._remote_acq.add_hook(hook, self._remote_acq.AFTER_CAMERA_HOOK)

        self._remote_acq.start()

        if magellan_acq_index is None and not magellan_explore:
            self.event_port = self._remote_acq.get_event_port()

            self.event_process = threading.Thread(
                target=_event_sending_fn,
                args=(self.event_port, self._event_queue, self._debug),
                name="Event sending",
            )
            self.event_process.start()
Exemple #11
0
    def __init__(self, dataset_path=None, full_res_only=True, remote_storage=None):
        """Open a dataset either from disk or as a view of an in-progress acquisition.

        Parameters
        ----------
        dataset_path : str
            Top-level dataset directory; it must contain a "Full resolution" subdirectory.
            Ignored when ``remote_storage`` is given.
        full_res_only : bool
            If True, only the "Full resolution" directory is opened.
        remote_storage :
            Java-side storage object of an active acquisition. When given, no files
            are read and the method returns right after reading the summary metadata.

        Raises
        ------
        Exception
            If ``dataset_path`` has no "Full resolution" subdirectory.
        """
        self._tile_width = None
        self._tile_height = None
        if remote_storage is not None:
            # this dataset is a view of an active acquisition. The storage exists on the java side
            self._remote_storage = remote_storage
            self._bridge = Bridge()
            smd = self._remote_storage.get_summary_metadata()
            if "GridPixelOverlapX" in smd.keys():
                self._tile_width = smd["Width"] - smd["GridPixelOverlapX"]
                self._tile_height = smd["Height"] - smd["GridPixelOverlapY"]
            return
        else:
            self._remote_storage = None

        self.path = dataset_path
        res_dirs = [
            dI for dI in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, dI))
        ]
        # map from downsample factor to dataset
        self.res_levels = {}
        if "Full resolution" not in res_dirs:
            raise Exception(
                "Couldn't find full resolution directory. Is this the correct path to a dataset?"
            )
        # total number of .tif files across all resolution dirs (passed to _ResolutionLevel)
        num_tiffs = 0
        count = 0
        for res_dir in res_dirs:
            for file in os.listdir(os.path.join(dataset_path, res_dir)):
                if file.endswith(".tif"):
                    num_tiffs += 1
        for res_dir in res_dirs:
            if full_res_only and res_dir != "Full resolution":
                continue
            res_dir_path = os.path.join(dataset_path, res_dir)
            res_level = _ResolutionLevel(res_dir_path, count, num_tiffs)
            if res_dir == "Full resolution":
                self.res_levels[0] = res_level
                # get summary metadata and index tree from full resolution image
                self.summary_metadata = res_level.summary_metadata

                self.overlap = (
                    np.array(
                        [
                            self.summary_metadata["GridPixelOverlapY"],
                            self.summary_metadata["GridPixelOverlapX"],
                        ]
                    )
                    if "GridPixelOverlapY" in self.summary_metadata
                    else None
                )

                # map of axis name -> set of positions where data exists,
                # built from the (axis, position) pairs of every index key
                self.axes = {}
                for axes_combo in res_level.index.keys():
                    for axis, position in axes_combo:
                        if axis not in self.axes.keys():
                            self.axes[axis] = set()
                        self.axes[axis].add(position)

                # figure out the mapping of channel name to position by reading image metadata.
                # Always initialized so datasets without a channel axis still expose an
                # (empty) mapping instead of a missing attribute.
                self._channel_names = {}
                print("\rReading channel names...", end="")
                if self._CHANNEL_AXIS in self.axes.keys():
                    for key in res_level.index.keys():
                        axes = {axis: position for axis, position in key}
                        if (
                            self._CHANNEL_AXIS in axes.keys()
                            and axes[self._CHANNEL_AXIS] not in self._channel_names.values()
                        ):
                            channel_name = res_level.read_metadata(axes)["Channel"]
                            self._channel_names[channel_name] = axes[self._CHANNEL_AXIS]
                        # stop once every channel position has a name
                        if len(self._channel_names.values()) == len(self.axes[self._CHANNEL_AXIS]):
                            break
                print("\rFinished reading channel names", end="")

                # remove axes with no variation
                single_axes = [axis for axis in self.axes if len(self.axes[axis]) == 1]
                for axis in single_axes:
                    del self.axes[axis]

                # TODO: map out row/col indices for XY-stitched ("TiledImageStorage") datasets

            else:
                # downsampled dirs are named like "Downsampled_x<factor>"; key by log2(factor)
                self.res_levels[int(np.log2(int(res_dir.split("x")[1])))] = res_level

        # get information about image width and height, assuming that they are consistent for whole dataset
        # (which isn't strictly necessary)
        first_index = list(self.res_levels[0].index.values())[0]
        # NOTE(review): pixel types other than these three leave bytes_per_pixel/dtype unset
        if first_index["pixel_type"] == _MultipageTiffReader.EIGHT_BIT_RGB:
            self.bytes_per_pixel = 3
            self.dtype = np.uint8
        elif first_index["pixel_type"] == _MultipageTiffReader.EIGHT_BIT:
            self.bytes_per_pixel = 1
            self.dtype = np.uint8
        elif first_index["pixel_type"] == _MultipageTiffReader.SIXTEEN_BIT:
            self.bytes_per_pixel = 2
            self.dtype = np.uint16

        self.image_width = first_index["image_width"]
        self.image_height = first_index["image_height"]
        if "GridPixelOverlapX" in self.summary_metadata:
            self._tile_width = self.image_width - self.summary_metadata["GridPixelOverlapX"]
            self._tile_height = self.image_height - self.summary_metadata["GridPixelOverlapY"]

        print("\rDataset opened                ")
Exemple #12
0
class Dataset:
    """Class that opens a single NDTiffStorage dataset.

    A ``Dataset`` is either a view of files saved on disk (``dataset_path``) or a
    view of the Java-side storage of an in-progress acquisition (``remote_storage``).
    """

    _POSITION_AXIS = "position"
    _Z_AXIS = "z"
    _TIME_AXIS = "time"
    _CHANNEL_AXIS = "channel"

    def __init__(self, dataset_path=None, full_res_only=True, remote_storage=None):
        """
        Parameters
        ----------
        dataset_path : str
            Top-level dataset directory; it must contain a "Full resolution" subdirectory.
            Ignored when ``remote_storage`` is given.
        full_res_only : bool
            If True, only the "Full resolution" directory is opened.
        remote_storage :
            Java-side storage object of an active acquisition. When given, no files are read.

        Raises
        ------
        Exception
            If ``dataset_path`` has no "Full resolution" subdirectory.
        """
        self._tile_width = None
        self._tile_height = None
        if remote_storage is not None:
            # this dataset is a view of an active acquisition. The storage exists on the java side
            self._remote_storage = remote_storage
            self._bridge = Bridge()
            smd = self._remote_storage.get_summary_metadata()
            if "GridPixelOverlapX" in smd.keys():
                self._tile_width = smd["Width"] - smd["GridPixelOverlapX"]
                self._tile_height = smd["Height"] - smd["GridPixelOverlapY"]
            return
        else:
            self._remote_storage = None

        self.path = dataset_path
        res_dirs = [
            dI for dI in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, dI))
        ]
        # map from downsample factor to dataset
        self.res_levels = {}
        if "Full resolution" not in res_dirs:
            raise Exception(
                "Couldn't find full resolution directory. Is this the correct path to a dataset?"
            )
        # total number of .tif files across all resolution dirs (passed to _ResolutionLevel)
        num_tiffs = 0
        count = 0
        for res_dir in res_dirs:
            for file in os.listdir(os.path.join(dataset_path, res_dir)):
                if file.endswith(".tif"):
                    num_tiffs += 1
        for res_dir in res_dirs:
            if full_res_only and res_dir != "Full resolution":
                continue
            res_dir_path = os.path.join(dataset_path, res_dir)
            res_level = _ResolutionLevel(res_dir_path, count, num_tiffs)
            if res_dir == "Full resolution":
                # TODO: might want to move this within the resolution level class to facilitate loading pyramids
                self.res_levels[0] = res_level
                # get summary metadata and index tree from full resolution image
                self.summary_metadata = res_level.reader_list[0].summary_md
                self.rgb = res_level.reader_list[0].rgb
                self._channel_names = {}  # read them from image metadata
                self._extra_axes_to_storage_channel = {}

                # store some fields explicitly for easy access
                self.dtype = (
                    np.uint16 if self.summary_metadata["PixelType"] == "GRAY16" else np.uint8
                )
                self.pixel_size_xy_um = self.summary_metadata["PixelSize_um"]
                self.pixel_size_z_um = (
                    self.summary_metadata["z-step_um"]
                    if "z-step_um" in self.summary_metadata
                    else None
                )
                self.image_width = res_level.reader_list[0].width
                self.image_height = res_level.reader_list[0].height
                self.overlap = (
                    np.array(
                        [
                            self.summary_metadata["GridPixelOverlapY"],
                            self.summary_metadata["GridPixelOverlapX"],
                        ]
                    )
                    if "GridPixelOverlapY" in self.summary_metadata
                    else None
                )
                c_z_t_p_tree = res_level.reader_tree
                # the c here refers to super channels, encompassing all non-tzp axes in addition to channels
                # map of axis names to values where data exists
                self.axes = {
                    self._Z_AXIS: set(),
                    self._TIME_AXIS: set(),
                    self._POSITION_AXIS: set(),
                    self._CHANNEL_AXIS: set(),
                }
                # Storage channels whose metadata has already been read. Every image in a storage
                # channel shares the same non-z/t/p axes, so one metadata read per storage channel
                # is enough. (Comparing c against channel-axis *values* could skip a storage
                # channel entirely when extra axes are present.)
                storage_channels_read = set()
                for c in c_z_t_p_tree.keys():
                    for z in c_z_t_p_tree[c]:
                        self.axes[self._Z_AXIS].add(z)
                        for t in c_z_t_p_tree[c][z]:
                            self.axes[self._TIME_AXIS].add(t)
                            for p in c_z_t_p_tree[c][z][t]:
                                self.axes[self._POSITION_AXIS].add(p)
                                if c in storage_channels_read:
                                    continue
                                storage_channels_read.add(c)
                                metadata = self.res_levels[0].read_metadata(
                                    channel_index=c, z_index=z, t_index=t, pos_index=p
                                )
                                current_axes = metadata["Axes"]
                                non_zpt_axes = {}
                                for axis in current_axes:
                                    if axis not in [
                                        self._Z_AXIS,
                                        self._TIME_AXIS,
                                        self._POSITION_AXIS,
                                    ]:
                                        if axis not in self.axes:
                                            self.axes[axis] = set()
                                        self.axes[axis].add(current_axes[axis])
                                        non_zpt_axes[axis] = current_axes[axis]

                                self._channel_names[metadata["Channel"]] = non_zpt_axes[
                                    self._CHANNEL_AXIS
                                ]
                                self._extra_axes_to_storage_channel[
                                    frozenset(non_zpt_axes.items())
                                ] = c

                # remove axes with no variation
                single_axes = [axis for axis in self.axes if len(self.axes[axis]) == 1]
                for axis in single_axes:
                    del self.axes[axis]

                if self._POSITION_AXIS in self.axes and "GridPixelOverlapX" in self.summary_metadata:
                    # Make an n x 2 array with nan's where no positions actually exist
                    num_positions = len(self.axes[self._POSITION_AXIS])
                    self.row_col_array = np.ones((num_positions, 2)) * np.nan
                    self.position_centers = np.ones((num_positions, 2)) * np.nan
                    # in case position index doesn't start at 0, map each position index to its
                    # row in the arrays above (its index into list(self.axes['position']))
                    position_to_row = {
                        p_index: i
                        for i, p_index in enumerate(list(self.axes[self._POSITION_AXIS]))
                    }
                    for c_index in c_z_t_p_tree.keys():
                        for z_index in c_z_t_p_tree[c_index].keys():
                            for t_index in c_z_t_p_tree[c_index][z_index].keys():
                                p_indices = c_z_t_p_tree[c_index][z_index][t_index].keys()
                                for p_index in p_indices:
                                    pos_index_index = position_to_row[p_index]
                                    if not np.isnan(self.row_col_array[pos_index_index, 0]):
                                        # already figured this one out
                                        continue
                                    if not res_level.check_ifd(
                                        channel_index=c_index,
                                        z_index=z_index,
                                        t_index=t_index,
                                        pos_index=p_index,
                                    ):
                                        # this position is corrupted; leave its row as nan
                                        warnings.warn(
                                            "Corrupted image p: {} c: {} t: {} z: {}".format(
                                                p_index, c_index, t_index, z_index
                                            )
                                        )
                                    else:
                                        md = res_level.read_metadata(
                                            channel_index=c_index,
                                            pos_index=p_index,
                                            t_index=t_index,
                                            z_index=z_index,
                                        )
                                        self.row_col_array[pos_index_index] = np.array(
                                            [md["GridRowIndex"], md["GridColumnIndex"]]
                                        )
                                        self.position_centers[pos_index_index] = np.array(
                                            [
                                                md["XPosition_um_Intended"],
                                                md["YPosition_um_Intended"],
                                            ]
                                        )

            else:
                # downsampled dirs are named like "Downsampled_x<factor>"; key by log2(factor)
                self.res_levels[int(np.log2(int(res_dir.split("x")[1])))] = res_level
        print("\rDataset opened")

    def as_array(self, stitched=False, verbose=False):
        """
        Read all data image data as one big Dask array with last two axes as y, x and preceding axes depending on data.
        The dask array is made up of memory-mapped numpy arrays, so the dataset does not need to be able to fit into RAM.
        If the data doesn't fully fill out the array (e.g. not every z-slice collected at every time point), zeros will
        be added automatically.

        To convert data into a numpy array, call np.asarray() on the returned result. However, doing so will bring the
        data into RAM, so it may be better to do this on only a slice of the array at a time.

        Parameters
        ----------
        stitched : bool
            If true and tiles were acquired in a grid, lay out adjacent tiles next to one another (Default value = False)
        verbose : bool
            If True print updates on progress loading the image
        Returns
        -------
        dataset : dask array

        Raises
        ------
        Exception
            For in-progress (remote) acquisitions.
        """
        if self._remote_storage is not None:
            raise Exception("Method not yet implemented for in progress acquisitions")
        # placeholder returned wherever no image exists
        self._empty_tile = (
            np.zeros((self.image_height, self.image_width), self.dtype)
            if not self.rgb
            else np.zeros((self.image_height, self.image_width, 3), self.dtype)
        )
        self._count = 1
        total = np.prod([len(v) for v in self.axes.values()])

        def recurse_axes(loop_axes, point_axes):
            # Build nested lists of tiles, one nesting level per axis in loop_axes.
            if len(loop_axes.values()) == 0:
                if verbose:
                    print("\rAdding data chunk {} of {}".format(self._count, total), end="")
                self._count += 1
                if None not in point_axes.values() and self.has_image(**point_axes):
                    return self.read_image(**point_axes, memmapped=True)
                else:
                    return self._empty_tile
            else:
                # do position first because it makes stitching faster
                axis = (
                    self._POSITION_AXIS
                    if self._POSITION_AXIS in loop_axes.keys() and stitched
                    else list(loop_axes.keys())[0]
                )
                remaining_axes = loop_axes.copy()
                del remaining_axes[axis]
                if axis == self._POSITION_AXIS and stitched:
                    # Stitch tiles acquired in a grid
                    self.half_overlap = self.overlap[0] // 2

                    # get spatial layout of position indices
                    zero_min_row_col = self.row_col_array - np.nanmin(self.row_col_array, axis=0)
                    row_col_mat = np.nan * np.ones(
                        [
                            int(np.nanmax(zero_min_row_col[:, 0])) + 1,
                            int(np.nanmax(zero_min_row_col[:, 1])) + 1,
                        ]
                    )
                    positions_indices = np.array(list(loop_axes[self._POSITION_AXIS]))
                    rows = zero_min_row_col[positions_indices][:, 0]
                    cols = zero_min_row_col[positions_indices][:, 1]
                    # mask in case some positions were corrupted
                    mask = np.logical_not(np.isnan(rows))
                    # builtin int: the np.int alias was removed in NumPy 1.24
                    row_col_mat[
                        rows[mask].astype(int), cols[mask].astype(int)
                    ] = positions_indices[mask]

                    blocks = []
                    for row in row_col_mat:
                        blocks.append([])
                        for p_index in row:
                            if verbose:
                                print(
                                    "\rAdding data chunk {} of {}".format(self._count, total),
                                    end="",
                                )
                            valed_axes = point_axes.copy()
                            # None marks grid cells with no position -> filled with empty tiles
                            valed_axes[axis] = int(p_index) if not np.isnan(p_index) else None
                            blocks[-1].append(da.stack(recurse_axes(remaining_axes, valed_axes)))

                    if self.rgb:
                        stitched_array = np.concatenate(
                            [
                                np.concatenate(row, axis=len(blocks[0][0].shape) - 2)
                                for row in blocks
                            ],
                            axis=len(blocks[0][0].shape) - 3,
                        )
                    else:
                        stitched_array = da.block(blocks)
                    return stitched_array
                else:
                    blocks = []
                    for val in loop_axes[axis]:
                        valed_axes = point_axes.copy()
                        valed_axes[axis] = val
                        blocks.append(recurse_axes(remaining_axes, valed_axes))
                    return blocks

        blocks = recurse_axes(self.axes, {})

        if verbose:
            print(
                " Stacking tiles"
            )  # extra space otherwise there is no space after the "Adding data chunk {} {}"
        array = da.stack(blocks)
        if verbose:
            print("\rDask array opened")
        return array

    def _convert_to_storage_axes(self, axes, channel_name=None):
        """Convert an arbitrary set of axes to the (c, t, p, z) indices used by the underlying storage.

        Parameters
        ----------
        axes : dict
            Axis name -> position. Modified in place: the channel entry is filled in
            (from ``channel_name`` or defaulting to 0).
        channel_name : str
            If given, overrides the channel position.

        Returns
        -------
        tuple
            ``(c_index, t_index, p_index, z_index)``; missing z/t/p default to 0.

        Raises
        ------
        Exception
            If ``channel_name`` or an axis name is unknown.
        KeyError
            If no storage channel holds this combination of non-z/t/p axes.
        """
        if channel_name is not None:
            if channel_name not in self._channel_names.keys():
                raise Exception("Channel name {} not found".format(channel_name))
            axes[self._CHANNEL_AXIS] = self._channel_names[channel_name]
        if self._CHANNEL_AXIS not in axes:
            axes[self._CHANNEL_AXIS] = 0

        z_index = axes[self._Z_AXIS] if self._Z_AXIS in axes else 0
        t_index = axes[self._TIME_AXIS] if self._TIME_AXIS in axes else 0
        p_index = axes[self._POSITION_AXIS] if self._POSITION_AXIS in axes else 0

        # everything other than z/t/p identifies a storage ("super") channel
        non_zpt_axes = {
            key: axes[key]
            for key in axes.keys()
            if key not in [self._TIME_AXIS, self._POSITION_AXIS, self._Z_AXIS]
        }
        for axis in non_zpt_axes.keys():
            if axis not in self.axes.keys() and axis != self._CHANNEL_AXIS:
                raise Exception("Unknown axis: {}".format(axis))
        c_index = self._extra_axes_to_storage_channel[frozenset(non_zpt_axes.items())]
        return c_index, t_index, p_index, z_index

    def has_image(
        self,
        channel=None,
        z=None,
        time=None,
        position=None,
        channel_name=None,
        resolution_level=0,
        row=None,
        col=None,
        **kwargs
    ):
        """Check if this image is present in the dataset

        Parameters
        ----------
        channel : int
            index of the channel, if applicable (Default value = None)
        z : int
            index of z slice, if applicable (Default value = None)
        time : int
            index of the time point, if applicable (Default value = None)
        position : int
            index of the XY position, if applicable (Default value = None)
        channel_name : str
            Name of the channel. Overrides channel index if supplied (Default value = None)
        row : int
            index of tile row for XY tiled datasets (Default value = None)
        col : int
            index of tile col for XY tiled datasets (Default value = None)
        resolution_level :
            0 is full resolution, otherwise represents downsampling of pixels
            at 2 ** (resolution_level) (Default value = 0)
        **kwargs
            Arbitrary keyword arguments

        Returns
        -------
        bool :
            indicating whether the dataset has an image matching the specifications
        """
        if channel is not None:
            kwargs["channel"] = channel
        if z is not None:
            kwargs["z"] = z
        if time is not None:
            kwargs["time"] = time
        if position is not None:
            kwargs["position"] = position

        if self._remote_storage is not None:
            axes = self._bridge.construct_java_object("java.util.HashMap")
            for key in kwargs.keys():
                axes.put(key, kwargs[key])
            if row is not None and col is not None:
                return self._remote_storage.has_tile_by_row_col(axes, resolution_level, row, col)
            else:
                return self._remote_storage.has_image(axes, resolution_level)

        if row is not None or col is not None:
            raise Exception("row col lookup not yet implmented for saved datasets")
            # self.row_col_array #TODO: find position index in here

        try:
            storage_c_index, t_index, p_index, z_index = self._convert_to_storage_axes(
                kwargs, channel_name=channel_name
            )
        except KeyError:
            # no storage channel holds this combination of extra axes -> no such image
            return False
        c_z_t_p_tree = self.res_levels[resolution_level].reader_tree
        if (
            storage_c_index in c_z_t_p_tree
            and z_index in c_z_t_p_tree[storage_c_index]
            and t_index in c_z_t_p_tree[storage_c_index][z_index]
            and p_index in c_z_t_p_tree[storage_c_index][z_index][t_index]
        ):
            res_level = self.res_levels[resolution_level]
            return res_level.check_ifd(
                channel_index=storage_c_index, z_index=z_index, t_index=t_index, pos_index=p_index
            )
        return False

    def read_image(
        self,
        channel=None,
        z=None,
        time=None,
        position=None,
        channel_name=None,
        read_metadata=False,
        resolution_level=0,
        row=None,
        col=None,
        memmapped=False,
        **kwargs
    ):
        """
        Read image data as numpy array

        Parameters
        ----------
        channel : int
            index of the channel, if applicable (Default value = None)
        z : int
            index of z slice, if applicable (Default value = None)
        time : int
            index of the time point, if applicable (Default value = None)
        position : int
            index of the XY position, if applicable (Default value = None)
        channel_name :
            Name of the channel. Overrides channel index if supplied (Default value = None)
        row : int
            index of tile row for XY tiled datasets (Default value = None)
        col : int
            index of tile col for XY tiled datasets (Default value = None)
        resolution_level :
            0 is full resolution, otherwise represents downsampling of pixels
            at 2 ** (resolution_level) (Default value = 0)
        read_metadata : bool
             (Default value = False)
        memmapped : bool
             (Default value = False)
        **kwargs :
            names and integer positions of any other axes

        Returns
        -------
        image : numpy array or tuple
            image as a 2D numpy array, or tuple with image and image metadata as dict.
            For in-progress acquisitions, None if the image does not exist.

        """
        if channel is not None:
            kwargs["channel"] = channel
        if z is not None:
            kwargs["z"] = z
        if time is not None:
            kwargs["time"] = time
        if position is not None:
            kwargs["position"] = position

        if self._remote_storage is not None:
            if memmapped:
                raise Exception("Memory mapping not available for in progress acquisitions")
            axes = self._bridge.construct_java_object("java.util.HashMap")
            for key in kwargs.keys():
                axes.put(key, kwargs[key])
            if not self._remote_storage.has_image(axes, resolution_level):
                return None
            if row is not None and col is not None:
                tagged_image = self._remote_storage.get_tile_by_row_col(
                    axes, resolution_level, row, col
                )
            else:
                tagged_image = self._remote_storage.get_image(axes, resolution_level)
            if tagged_image is None:
                return None
            if resolution_level == 0:
                image = np.reshape(
                    tagged_image.pix,
                    newshape=[tagged_image.tags["Height"], tagged_image.tags["Width"]],
                )
                if (self._tile_height is not None) and (self._tile_width is not None):
                    # crop down to just the part that shows (i.e. no overlap). Explicit stop
                    # indices keep this correct when the overlap is 0 (a "-0" stop would
                    # produce an empty slice).
                    y_start = (image.shape[0] - self._tile_height) // 2
                    x_start = (image.shape[1] - self._tile_width) // 2
                    image = image[
                        y_start : y_start + self._tile_height,
                        x_start : x_start + self._tile_width,
                    ]
            else:
                image = np.reshape(tagged_image.pix, newshape=[self._tile_height, self._tile_width])
            if read_metadata:
                return image, tagged_image.tags
            return image

        if row is not None or col is not None:
            raise Exception("row col lookup not yet implmented for saved datasets")
            # self.row_col_array #TODO: find position index in here

        storage_c_index, t_index, p_index, z_index = self._convert_to_storage_axes(
            kwargs, channel_name=channel_name
        )
        res_level = self.res_levels[resolution_level]
        return res_level.read_image(
            storage_c_index, z_index, t_index, p_index, read_metadata, memmapped
        )

    def read_first_image_metadata(self):
        """
        Get the metadata of the first image in the dataset (the first key at each level of
        the c/z/t/p index tree). This is useful for accessing image metadata in a dataset with
        sparse, nonzero axes.

        Returns
        -------
        metadata : dict

        """
        cztp_tree = self.res_levels[0].reader_tree
        c = list(cztp_tree.keys())[0]
        z = list(cztp_tree[c].keys())[0]
        t = list(cztp_tree[c][z].keys())[0]
        p = list(cztp_tree[c][z][t].keys())[0]
        return self.res_levels[0].read_metadata(c, z, t, p)

    def read_metadata(
        self,
        channel=None,
        z=None,
        time=None,
        position=None,
        channel_name=None,
        row=None,
        col=None,
        resolution_level=0,
        **kwargs
    ):
        """
        Read metadata only. Faster than using read_image to retrieve metadata

        Parameters
        ----------
        channel : int
            index of the channel, if applicable (Default value = None)
        z : int
            index of z slice, if applicable (Default value = None)
        time : int
            index of the time point, if applicable (Default value = None)
        position : int
            index of the XY position, if applicable (Default value = None)
        channel_name :
            Name of the channel. Overrides channel index if supplied (Default value = None)
        row : int
            index of tile row for XY tiled datasets (Default value = None)
        col : int
            index of tile col for XY tiled datasets (Default value = None)
        resolution_level :
            0 is full resolution, otherwise represents downsampling of pixels
            at 2 ** (resolution_level) (Default value = 0)
        **kwargs :
            names and integer positions of any other axes

        Returns
        -------
        metadata : dict
            For in-progress acquisitions, None if the image does not exist.

        """
        if channel is not None:
            kwargs["channel"] = channel
        if z is not None:
            kwargs["z"] = z
        if time is not None:
            kwargs["time"] = time
        if position is not None:
            kwargs["position"] = position

        if self._remote_storage is not None:
            # read the tagged image because no function in Java API rn for metadata only.
            # kwargs already carries channel/z/time/position, so they are forwarded only
            # once (passing them explicitly too raises "multiple values" TypeError).
            result = self.read_image(
                channel_name=channel_name,
                read_metadata=True,
                resolution_level=resolution_level,
                row=row,
                col=col,
                **kwargs
            )
            return None if result is None else result[1]

        storage_c_index, t_index, p_index, z_index = self._convert_to_storage_axes(
            kwargs, channel_name=channel_name
        )
        res_level = self.res_levels[resolution_level]
        return res_level.read_metadata(storage_c_index, z_index, t_index, p_index)

    def close(self):
        """Close all resolution levels of an on-disk dataset (no-op for in-progress acquisitions)."""
        if self._remote_storage is not None:
            # nothing to do, this is handled on the java side
            return
        # res_levels maps level index -> resolution level; close the level objects, not the keys
        for res_level in self.res_levels.values():
            res_level.close()

    def get_channel_names(self):
        """Return the channel names found in image metadata (dict keys view).

        Raises
        ------
        Exception
            For in-progress (remote) acquisitions.
        """
        if self._remote_storage is not None:
            raise Exception("Not implemented for in progress datasets")
        return self._channel_names.keys()
Exemple #13
0
class Dataset:
    """
    Class that opens a single NDTiffStorage dataset.

    The data is either read from files on disk (``dataset_path``) or, when ``remote_storage``
    is given, accessed through a view of an acquisition that is still in progress and whose
    storage lives on the Java side.
    """

    _POSITION_AXIS = 'position'
    _Z_AXIS = 'z'
    _TIME_AXIS = 'time'
    _CHANNEL_AXIS = 'channel'

    def __init__(self,
                 dataset_path=None,
                 full_res_only=True,
                 remote_storage=None):
        """
        :param dataset_path: path to a saved dataset directory, which must contain a
            'Full resolution' sub-directory. Not used when ``remote_storage`` is supplied
        :type dataset_path: str
        :param full_res_only: if True, only open the 'Full resolution' level and skip the
            downsampled levels
        :type full_res_only: boolean
        :param remote_storage: Java-side storage of an in-progress acquisition
        :raises Exception: if ``dataset_path`` has no 'Full resolution' directory
        """
        # Tile size without overlap; only known for tiled in-progress acquisitions
        self._tile_width = None
        self._tile_height = None
        if remote_storage is not None:
            # this dataset is a view of an active acquisition. The storage exists on the java side
            self._remote_storage = remote_storage
            self._bridge = Bridge()
            smd = self._remote_storage.get_summary_metadata()
            # overlap keys are absent for non-tiled acquisitions, so don't require them
            if 'GridPixelOverlapX' in smd.keys():
                self._tile_width = smd['Width'] - smd['GridPixelOverlapX']
                self._tile_height = smd['Height'] - smd['GridPixelOverlapY']
            return
        self._remote_storage = None

        self.path = dataset_path
        res_dirs = [
            name for name in os.listdir(dataset_path)
            if os.path.isdir(os.path.join(dataset_path, name))
        ]
        # map from resolution level (log2 of the downsample factor) to _ResolutionLevel
        self.res_levels = {}
        if 'Full resolution' not in res_dirs:
            raise Exception(
                'Couldn\'t find full resolution directory. Is this the correct path to a dataset?'
            )
        num_tiffs = 0
        count = 0
        for res_dir in res_dirs:
            for file in os.listdir(os.path.join(dataset_path, res_dir)):
                if file.endswith('.tif'):
                    num_tiffs += 1
        for res_dir in res_dirs:
            if full_res_only and res_dir != 'Full resolution':
                continue
            res_dir_path = os.path.join(dataset_path, res_dir)
            res_level = _ResolutionLevel(res_dir_path, count, num_tiffs)
            if res_dir == 'Full resolution':
                # TODO: might want to move this within the resolution level class to facilitate loading pyramids
                self.res_levels[0] = res_level
                self._init_from_full_resolution(res_level)
            else:
                # downsampled level directories are named '...x<factor>'; key by log2(factor)
                self.res_levels[int(np.log2(int(
                    res_dir.split('x')[1])))] = res_level
        print('\rDataset opened')

    def _init_from_full_resolution(self, res_level):
        """
        Read summary metadata, image geometry and the axes present from the full resolution level.

        :param res_level: the opened 'Full resolution' _ResolutionLevel
        """
        # get summary metadata and index tree from full resolution image
        self.summary_metadata = res_level.reader_list[0].summary_md
        self._channel_names = {}  # read them from image metadata
        self._extra_axes_to_storage_channel = {}

        # store some fields explicitly for easy access
        self.dtype = np.uint16 if self.summary_metadata[
            'PixelType'] == 'GRAY16' else np.uint8
        self.pixel_size_xy_um = self.summary_metadata['PixelSize_um']
        self.pixel_size_z_um = self.summary_metadata[
            'z-step_um'] if 'z-step_um' in self.summary_metadata else None
        self.image_width = res_level.reader_list[0].width
        self.image_height = res_level.reader_list[0].height
        self.overlap = np.array([
            self.summary_metadata['GridPixelOverlapY'],
            self.summary_metadata['GridPixelOverlapX']
        ]) if 'GridPixelOverlapY' in self.summary_metadata else None

        # the c here refers to "super channels", encompassing all non-tzp axes in addition to channels
        c_z_t_p_tree = res_level.reader_tree
        # map of axis names to values where data exists
        self.axes = {
            self._Z_AXIS: set(),
            self._TIME_AXIS: set(),
            self._POSITION_AXIS: set(),
            self._CHANNEL_AXIS: set()
        }
        # storage (super) channels whose image metadata has already been parsed; tracked
        # separately because storage indices are not the same thing as channel axis values
        parsed_storage_channels = set()
        for c, z_tree in c_z_t_p_tree.items():
            for z, t_tree in z_tree.items():
                self.axes[self._Z_AXIS].add(z)
                for t, p_tree in t_tree.items():
                    self.axes[self._TIME_AXIS].add(t)
                    for p in p_tree:
                        self.axes[self._POSITION_AXIS].add(p)
                        if c not in parsed_storage_channels:
                            parsed_storage_channels.add(c)
                            self._register_storage_channel(res_level, c, z, t, p)

        # remove axes with no variation
        single_axes = [
            axis for axis, values in self.axes.items() if len(values) == 1
        ]
        for axis in single_axes:
            del self.axes[axis]

        if self._POSITION_AXIS in self.axes and 'GridPixelOverlapX' in self.summary_metadata:
            self.row_col_array = self._build_row_col_array(res_level, c_z_t_p_tree)

    def _register_storage_channel(self, res_level, c, z, t, p):
        """
        Record the non z/time/position axes of storage channel ``c`` by reading the metadata of
        its image at (``z``, ``t``, ``p``).
        """
        metadata = res_level.read_metadata(channel_index=c,
                                           z_index=z,
                                           t_index=t,
                                           pos_index=p)
        current_axes = metadata['Axes']
        non_zpt_axes = {}
        for axis in current_axes:
            if axis in (self._Z_AXIS, self._TIME_AXIS, self._POSITION_AXIS):
                continue
            self.axes.setdefault(axis, set()).add(current_axes[axis])
            non_zpt_axes[axis] = current_axes[axis]

        self._channel_names[metadata['Channel']] = non_zpt_axes[self._CHANNEL_AXIS]
        self._extra_axes_to_storage_channel[frozenset(non_zpt_axes.items())] = c

    def _build_row_col_array(self, res_level, c_z_t_p_tree):
        """
        Build an n x 2 array of [GridRowIndex, GridColumnIndex] where row i belongs to position i.

        Rows are NaN for position indices that have no image anywhere in the dataset, or whose
        first image found is corrupted.
        """
        row_col_by_position = {}
        for c_index, z_tree in c_z_t_p_tree.items():
            for z_index, t_tree in z_tree.items():
                for t_index, p_tree in t_tree.items():
                    for p_index in p_tree:
                        if p_index in row_col_by_position:
                            continue
                        if not res_level.check_ifd(channel_index=c_index,
                                                   z_index=z_index,
                                                   t_index=t_index,
                                                   pos_index=p_index):
                            # this position is corrupted
                            warnings.warn(
                                'Corrupted image p: {} c: {} t: {} z: {}'.format(
                                    p_index, c_index, t_index, z_index))
                            row_col_by_position[p_index] = (np.nan, np.nan)
                        else:
                            md = res_level.read_metadata(channel_index=c_index,
                                                         pos_index=p_index,
                                                         t_index=t_index,
                                                         z_index=z_index)
                            row_col_by_position[p_index] = (md['GridRowIndex'],
                                                            md['GridColumnIndex'])
        # exactly one row per position index keeps the array aligned with position indices
        num_positions = max(row_col_by_position) + 1
        return np.stack([
            np.array(row_col_by_position.get(p_index, (np.nan, np.nan)))
            for p_index in range(num_positions)
        ])

    def as_array(self, stitched=False):
        """
        Read all image data as one big Dask array with the last two axes as y, x and the preceding
        axes depending on the data.
        The dask array is made up of memory-mapped numpy arrays, so the dataset does not need to be able to fit into RAM.
        If the data doesn't fully fill out the array (e.g. not every z-slice collected at every time point), zeros will
        be added automatically.

        To convert data into a numpy array, call np.asarray() on the returned result. However, doing so will bring the
        data into RAM, so it may be better to do this on only a slice of the array at a time.

        :param stitched: If true and tiles were acquired in a grid, lay out adjacent tiles next to one another
        :type stitched: boolean
        :return: dask array of the dataset
        :raises Exception: for in-progress acquisitions
        """
        if self._remote_storage is not None:
            raise Exception(
                'Method not yet implemented for in progress acquisitions')
        self._empty_tile = np.zeros((self.image_height, self.image_width),
                                    self.dtype)
        self._count = 1
        total = np.prod([len(v) for v in self.axes.values()])

        def recurse_axes(loop_axes, point_axes):
            # loop_axes: axes still to iterate (name -> values)
            # point_axes: axis values fixed so far (None marks an empty slot in a stitched grid)
            if len(loop_axes) == 0:
                print('\rAdding data chunk {} of {}'.format(
                    self._count, total),
                      end='')
                self._count += 1
                if None not in point_axes.values() and self.has_image(
                        **point_axes):
                    return self.read_image(**point_axes, memmapped=True)
                # missing image: reuse one shared tile of zeros
                return self._empty_tile

            # do position first because it makes stitching faster
            if stitched and 'position' in loop_axes:
                axis = 'position'
            else:
                axis = list(loop_axes.keys())[0]
            remaining_axes = loop_axes.copy()
            del remaining_axes[axis]

            if axis == 'position' and stitched:
                # Stitch tiles acquired in a grid
                self.half_overlap = self.overlap[0] // 2

                # get spatial layout of position indices
                zero_min_row_col = (self.row_col_array -
                                    np.nanmin(self.row_col_array, axis=0))
                row_col_mat = np.nan * np.ones([
                    int(np.nanmax(zero_min_row_col[:, 0])) + 1,
                    int(np.nanmax(zero_min_row_col[:, 1])) + 1
                ])
                positions_indices = np.array(list(loop_axes['position']))
                rows = zero_min_row_col[positions_indices][:, 0]
                cols = zero_min_row_col[positions_indices][:, 1]
                # mask in case some positions were corrupted
                mask = np.logical_not(np.isnan(rows))
                # builtin int instead of np.int, which was removed in NumPy 1.24
                row_col_mat[rows[mask].astype(int),
                            cols[mask].astype(int)] = positions_indices[mask]

                blocks = []
                for row in row_col_mat:
                    blocks.append([])
                    for p_index in row:
                        print('\rAdding data chunk {} of {}'.format(
                            self._count, total),
                              end='')
                        valed_axes = point_axes.copy()
                        valed_axes[axis] = int(
                            p_index) if not np.isnan(p_index) else None
                        blocks[-1].append(
                            da.stack(recurse_axes(remaining_axes, valed_axes)))

                return da.block(blocks)

            blocks = []
            for val in loop_axes[axis]:
                valed_axes = point_axes.copy()
                valed_axes[axis] = val
                blocks.append(recurse_axes(remaining_axes, valed_axes))
            return blocks

        blocks = recurse_axes(self.axes, {})

        print('Stacking tiles')
        array = da.stack(blocks)
        print('\rDask array opened')
        return array

    def _convert_to_storage_axes(self, axes, channel_name=None):
        """
        Convert an arbitrary set of axes to the cztp indices used by the underlying storage.

        :param axes: dict of axis name -> integer position. Modified in place: the channel axis is
            set from ``channel_name`` if given, or to 0 when absent
        :param channel_name: name of the channel; overrides any channel index in ``axes``
        :return: tuple (c_index, t_index, p_index, z_index); z, time and position default to 0
        :raises Exception: for an unknown channel name or axis
        :raises KeyError: if the combination of non z/time/position axes is not in the dataset
        """
        if channel_name is not None:
            if channel_name not in self._channel_names.keys():
                raise Exception(
                    'Channel name {} not found'.format(channel_name))
            axes[self._CHANNEL_AXIS] = self._channel_names[channel_name]
        if self._CHANNEL_AXIS not in axes:
            axes[self._CHANNEL_AXIS] = 0

        z_index = axes[self._Z_AXIS] if self._Z_AXIS in axes else 0
        t_index = axes[self._TIME_AXIS] if self._TIME_AXIS in axes else 0
        p_index = axes[
            self._POSITION_AXIS] if self._POSITION_AXIS in axes else 0

        non_zpt_axes = {
            key: axes[key]
            for key in axes.keys()
            if key not in [self._TIME_AXIS, self._POSITION_AXIS, self._Z_AXIS]
        }
        for axis in non_zpt_axes.keys():
            if axis not in self.axes.keys() and axis != 'channel':
                raise Exception('Unknown axis: {}'.format(axis))
        c_index = self._extra_axes_to_storage_channel[frozenset(
            non_zpt_axes.items())]
        return c_index, t_index, p_index, z_index

    def has_image(self,
                  channel=None,
                  z=None,
                  time=None,
                  position=None,
                  channel_name=None,
                  resolution_level=0,
                  row=None,
                  col=None,
                  **kwargs):
        """
        Check if this image is present in the dataset

        :param channel: index of the channel, if applicable
        :type channel: int
        :param z: index of z slice, if applicable
        :type z: int
        :param time: index of the time point, if applicable
        :type time: int
        :param position: index of the XY position, if applicable
        :type position: int
        :param channel_name: Name of the channel. Overrides channel index if supplied
        :type channel_name: str
        :param row: index of tile row for XY tiled datasets
        :type row: int
        :param col: index of tile col for XY tiled datasets
        :type col: int
        :param resolution_level: 0 is full resolution, otherwise represents downsampling of pixels
            at 2 ** (resolution_level)
        :param kwargs: names and integer positions of any other axes
        :return: boolean indicating whether image present
        :raises Exception: for row/col lookups on saved datasets, unknown axes or channel names
        """
        if channel is not None:
            kwargs['channel'] = channel
        if z is not None:
            kwargs['z'] = z
        if time is not None:
            kwargs['time'] = time
        if position is not None:
            kwargs['position'] = position

        if self._remote_storage is not None:
            axes = self._bridge.construct_java_object('java.util.HashMap')
            for key in kwargs.keys():
                axes.put(key, kwargs[key])
            if row is not None and col is not None:
                return self._remote_storage.has_tile_by_row_col(
                    axes, resolution_level, row, col)
            return self._remote_storage.has_image(axes, resolution_level)

        if row is not None or col is not None:
            raise Exception(
                'row col lookup not yet implmented for saved datasets')
            # self.row_col_array #TODO: find position index in here

        try:
            storage_c_index, t_index, p_index, z_index = self._convert_to_storage_axes(
                kwargs, channel_name=channel_name)
        except KeyError:
            # this combination of channel/extra axes does not exist in the dataset
            return False
        c_z_t_p_tree = self.res_levels[resolution_level].reader_tree
        if storage_c_index in c_z_t_p_tree and z_index in c_z_t_p_tree[storage_c_index] and t_index in \
                c_z_t_p_tree[storage_c_index][z_index] and p_index in c_z_t_p_tree[storage_c_index][z_index][t_index]:
            res_level = self.res_levels[resolution_level]
            return res_level.check_ifd(channel_index=storage_c_index,
                                       z_index=z_index,
                                       t_index=t_index,
                                       pos_index=p_index)
        return False

    def read_image(self,
                   channel=None,
                   z=None,
                   time=None,
                   position=None,
                   channel_name=None,
                   read_metadata=False,
                   resolution_level=0,
                   row=None,
                   col=None,
                   memmapped=False,
                   **kwargs):
        """
        Read image data as numpy array

        :param channel: index of the channel, if applicable
        :type channel: int
        :param z: index of z slice, if applicable
        :type z: int
        :param time: index of the time point, if applicable
        :type time: int
        :param position: index of the XY position, if applicable
        :type position: int
        :param channel_name: Name of the channel. Overrides channel index if supplied
        :param read_metadata: if True, also return the image metadata
        :param row: index of tile row for XY tiled datasets
        :type row: int
        :param col: index of tile col for XY tiled datasets
        :type col: int
        :param resolution_level: 0 is full resolution, otherwise represents downsampling of pixels
            at 2 ** (resolution_level)
        :param memmapped: return a memory-mapped array (saved datasets only)
        :param kwargs: names and integer positions of any other axes
        :return: image as 2D numpy array, or tuple with image and image metadata as dict.
            For in-progress acquisitions, None if the image is not (yet) available
        """
        if channel is not None:
            kwargs['channel'] = channel
        if z is not None:
            kwargs['z'] = z
        if time is not None:
            kwargs['time'] = time
        if position is not None:
            kwargs['position'] = position

        if self._remote_storage is not None:
            if memmapped:
                raise Exception(
                    'Memory mapping not available for in progress acquisitions'
                )
            axes = self._bridge.construct_java_object('java.util.HashMap')
            for key in kwargs.keys():
                axes.put(key, kwargs[key])
            if not self._remote_storage.has_image(axes, resolution_level):
                return None
            if row is not None and col is not None:
                tagged_image = self._remote_storage.get_tile_by_row_col(
                    axes, resolution_level, row, col)
            else:
                tagged_image = self._remote_storage.get_image(
                    axes, resolution_level)
            if tagged_image is None:
                return None
            if resolution_level == 0:
                image = np.reshape(tagged_image.pix, [
                    tagged_image.tags['Height'], tagged_image.tags['Width']
                ])
                if self._tile_height is not None and self._tile_width is not None:
                    # crop down to just the part that shows (i.e. no overlap). Explicit stop
                    # indices keep the full image when the overlap is 0
                    top = (image.shape[0] - self._tile_height) // 2
                    left = (image.shape[1] - self._tile_width) // 2
                    image = image[top:top + self._tile_height,
                                  left:left + self._tile_width]
            else:
                # NOTE(review): requires tile dimensions, which are only known for tiled acquisitions
                image = np.reshape(tagged_image.pix,
                                   [self._tile_height, self._tile_width])
            if read_metadata:
                return image, tagged_image.tags
            return image

        if row is not None or col is not None:
            raise Exception(
                'row col lookup not yet implmented for saved datasets')
            # self.row_col_array #TODO: find position index in here

        storage_c_index, t_index, p_index, z_index = self._convert_to_storage_axes(
            kwargs, channel_name=channel_name)
        res_level = self.res_levels[resolution_level]
        return res_level.read_image(storage_c_index, z_index, t_index, p_index,
                                    read_metadata, memmapped)

    def read_metadata(self,
                      channel=None,
                      z=None,
                      time=None,
                      position=None,
                      channel_name=None,
                      row=None,
                      col=None,
                      resolution_level=0,
                      **kwargs):
        """
        Read metadata only. Faster than using read_image to retrieve metadata

        :param channel: index of the channel, if applicable
        :type channel: int
        :param z: index of z slice, if applicable
        :type z: int
        :param time: index of the time point, if applicable
        :type time: int
        :param position: index of the XY position, if applicable
        :type position: int
        :param channel_name: Name of the channel. Overrides channel index if supplied
        :param row: index of tile row for XY tiled datasets
        :type row: int
        :param col: index of tile col for XY tiled datasets
        :type col: int
        :param resolution_level: 0 is full resolution, otherwise represents downsampling of pixels
            at 2 ** (resolution_level)
        :param kwargs: names and integer positions of any other axes
        :return: image metadata as dict (None for in-progress acquisitions lacking the image)
        """
        if channel is not None:
            kwargs['channel'] = channel
        if z is not None:
            kwargs['z'] = z
        if time is not None:
            kwargs['time'] = time
        if position is not None:
            kwargs['position'] = position

        if self._remote_storage is not None:
            # read the tagged image because there is no function in the Java API for metadata only.
            # channel/z/time/position are already in kwargs; passing them explicitly too
            # would be a duplicate keyword argument
            result = self.read_image(channel_name=channel_name,
                                     read_metadata=True,
                                     resolution_level=resolution_level,
                                     row=row,
                                     col=col,
                                     **kwargs)
            return None if result is None else result[1]

        storage_c_index, t_index, p_index, z_index = self._convert_to_storage_axes(
            kwargs, channel_name=channel_name)
        res_level = self.res_levels[resolution_level]
        return res_level.read_metadata(storage_c_index, z_index, t_index,
                                       p_index)

    def close(self):
        """Close all opened resolution levels (no-op for in-progress acquisitions)."""
        if self._remote_storage is not None:
            # nothing to do, this is handled on the java side
            return
        # res_levels is a dict; iterating it directly would yield the int keys
        for res_level in self.res_levels.values():
            res_level.close()

    def get_channel_names(self):
        """Return the channel names of a saved dataset."""
        if self._remote_storage is not None:
            raise Exception('Not implemented for in progress datasets')
        return self._channel_names.keys()
Exemple #14
0
def start_headless(mm_app_path,
                   config_file,
                   java_loc=None,
                   core_log_path=None,
                   buffer_size_mb=1024):
    """
    Start a Java process that contains the necessary libraries for pycro-manager to run,
    so that it can be run independently of the Micro-Manager GUI/application. This call
    will create and initialize MMCore with the configuration file provided.

    On Windows platforms, the Java Runtime Environment will be grabbed automatically
    as it is installed along with the Micro-Manager application.

    On non-Windows platforms, it may need to be installed/specified manually in order to ensure compatibility.
    This can be checked by looking at the "maven.compiler.source" entry which the java parts of
    pycro-manager were compiled with. See here: https://github.com/micro-manager/pycro-manager/blob/29b584bfd71f0d05750f5d39600318902186a06a/java/pom.xml#L8

    Parameters
    ----------
    mm_app_path : str
        Path to top level folder of Micro-Manager installation (made with graphical installer)
    config_file : str
        Path to micro-manager config file, with which core will be initialized
    java_loc : str
        Path to the java version that it should be run with (Default value = None: the JRE
        bundled with Micro-Manager on Windows, ``java`` from PATH elsewhere)
    core_log_path : str
        Path to where core log files should be created (Default value = None: no log file)
    buffer_size_mb : int
        Size of circular buffer in MB in MMCore (Default value = 1024)
    """
    # The arguments go to Popen as a list (no shell), so the classpath must not contain literal
    # quote characters: they would become part of the path. subprocess quotes arguments that
    # contain spaces itself where the platform needs it, and Java expands the trailing '*'.
    classpath = mm_app_path + '/plugins/Micro-Manager/*'
    if java_loc is None:
        if platform.system() == "Windows":
            # windows comes with its own JRE
            java_loc = mm_app_path + "/jre/bin/javaw.exe"
        else:
            java_loc = "java"
    # This starts Java process and instantiates essential objects (core,
    # acquisition engine, ZMQServer)
    p = subprocess.Popen([
        java_loc,
        "-classpath",
        classpath,
        "-Dsun.java2d.dpiaware=false",
        "-Xmx2000m",
        # This is used by MM desktop app but breaks things on MacOS...Don't think its necessary
        # "-XX:MaxDirectMemorySize=1000",
        "org.micromanager.remote.HeadlessLauncher",
    ])
    # make sure Java process cleans up when Python process exits
    atexit.register(p.terminate)

    # Initialize core
    bridge = Bridge()
    core = bridge.get_core()

    core.wait_for_system()
    core.load_system_configuration(config_file)

    core.set_circular_buffer_memory_footprint(buffer_size_mb)

    if core_log_path is not None:
        core.enable_stderr_log(True)
        core.enable_debug_log(True)
        core.set_primary_log_file(core_log_path)
Exemple #15
0
    def __init__(self,
                 dataset_path=None,
                 full_res_only=True,
                 remote_storage=None):
        self._tile_width = None
        self._tile_height = None
        if remote_storage is not None:
            # this dataset is a view of an active acquisiiton. The storage exists on the java side
            self._remote_storage = remote_storage
            self._bridge = Bridge()
            smd = self._remote_storage.get_summary_metadata()
            if "GridPixelOverlapX" in smd.keys():
                self._tile_width = smd["Width"] - smd["GridPixelOverlapX"]
                self._tile_height = smd["Height"] - smd["GridPixelOverlapY"]
            return
        else:
            self._remote_storage = None

        self.path = dataset_path
        res_dirs = [
            dI for dI in os.listdir(dataset_path)
            if os.path.isdir(os.path.join(dataset_path, dI))
        ]
        # map from downsample factor to datset
        self.res_levels = {}
        if "Full resolution" not in res_dirs:
            raise Exception(
                "Couldn't find full resolution directory. Is this the correct path to a dataset?"
            )
        num_tiffs = 0
        count = 0
        for res_dir in res_dirs:
            for file in os.listdir(os.path.join(dataset_path, res_dir)):
                if file.endswith(".tif"):
                    num_tiffs += 1
        for res_dir in res_dirs:
            if full_res_only and res_dir != "Full resolution":
                continue
            res_dir_path = os.path.join(dataset_path, res_dir)
            res_level = _ResolutionLevel(res_dir_path, count, num_tiffs)
            if res_dir == "Full resolution":
                # TODO: might want to move this within the resolution level class to facilitate loading pyramids
                self.res_levels[0] = res_level
                # get summary metadata and index tree from full resolution image
                self.summary_metadata = res_level.reader_list[0].summary_md
                self.rgb = res_level.reader_list[0].rgb
                self._channel_names = {}  # read them from image metadata
                self._extra_axes_to_storage_channel = {}

                # store some fields explicitly for easy access
                self.dtype = (np.uint16 if self.summary_metadata["PixelType"]
                              == "GRAY16" else np.uint8)
                self.pixel_size_xy_um = self.summary_metadata["PixelSize_um"]
                self.pixel_size_z_um = (self.summary_metadata["z-step_um"]
                                        if "z-step_um" in self.summary_metadata
                                        else None)
                self.image_width = res_level.reader_list[0].width
                self.image_height = res_level.reader_list[0].height
                self.overlap = (np.array([
                    self.summary_metadata["GridPixelOverlapY"],
                    self.summary_metadata["GridPixelOverlapX"],
                ]) if "GridPixelOverlapY" in self.summary_metadata else None)
                c_z_t_p_tree = res_level.reader_tree
                # the c here refers to super channels, encompassing all non-tzp axes in addition to channels
                # map of axis names to values where data exists
                self.axes = {
                    self._Z_AXIS: set(),
                    self._TIME_AXIS: set(),
                    self._POSITION_AXIS: set(),
                    self._CHANNEL_AXIS: set(),
                }

                # Need to map "super channels", which absorb all non channel/z/time/position axes to channel indices
                # used by underlying storage
                def parse_axes(current_axes, channel_index):
                    non_zpt_axes = {}
                    for axis_name in current_axes:
                        if axis_name not in [
                                self._Z_AXIS,
                                self._TIME_AXIS,
                                self._POSITION_AXIS,
                        ]:
                            if axis_name not in self.axes:
                                self.axes[axis_name] = set()
                            self.axes[axis_name].add(current_axes[axis_name])
                            non_zpt_axes[axis_name] = current_axes[axis_name]

                    self._extra_axes_to_storage_channel[frozenset(
                        non_zpt_axes.items())] = channel_index
                    return non_zpt_axes

                print("Parsing metadata\r", end="")
                if "Axes_metedata" in os.listdir(dataset_path):
                    # newer version with a metadata file where this is written explicitly
                    with open(
                            dataset_path +
                        (os.sep if dataset_path[-1] != os.sep else "") +
                            "Axes_metedata",
                            "rb",
                    ) as axes_metadata_file:
                        content = axes_metadata_file.read()
                    while len(content) > 0:
                        (flag, ) = struct.unpack("i", content[:4])
                        if flag == -1:
                            channel_index, length = struct.unpack(
                                "ii", content[4:12])
                            channel_name = content[12:12 +
                                                   length].decode("iso-8859-1")
                            # contains channel name metadata
                            self._channel_names[channel_name] = channel_index
                            content = content[12 + length:]
                        else:
                            channel_index = flag
                            (length, ) = struct.unpack("i", content[4:8])
                            # contains super channel metadata
                            other_axes = content[8:8 +
                                                 length].decode("iso-8859-1")
                            current_axes = {
                                axis_pos.split("_")[0]:
                                int(axis_pos.split("_")[1])
                                for axis_pos in other_axes.split("Axis_")
                                if len(axis_pos) > 0
                            }
                            parse_axes(current_axes, channel_index)
                            content = content[8 + length:]
                    # add standard time position z axes as well
                    for c in c_z_t_p_tree.keys():
                        for z in c_z_t_p_tree[c]:
                            self.axes[self._Z_AXIS].add(z)
                            for t in c_z_t_p_tree[c][z]:
                                self.axes[self._TIME_AXIS].add(t)
                                for p in c_z_t_p_tree[c][z][t]:
                                    self.axes[self._POSITION_AXIS].add(p)

                else:
                    # older version of NDTiffStorage, recover by brute force search through image metadata (slow)
                    for c in c_z_t_p_tree.keys():
                        for z in c_z_t_p_tree[c]:
                            self.axes[self._Z_AXIS].add(z)
                            for t in c_z_t_p_tree[c][z]:
                                self.axes[self._TIME_AXIS].add(t)
                                for p in c_z_t_p_tree[c][z][t]:
                                    self.axes[self._POSITION_AXIS].add(p)
                                    if c not in self.axes["channel"]:
                                        metadata = self.res_levels[
                                            0].read_metadata(channel_index=c,
                                                             z_index=z,
                                                             t_index=t,
                                                             pos_index=p)
                                        current_axes = metadata["Axes"]
                                        non_zpt_axes = parse_axes(
                                            current_axes, c)
                                        # make a map of channel names to channel indices
                                        self._channel_names[metadata[
                                            "Channel"]] = non_zpt_axes[
                                                self._CHANNEL_AXIS]
                print("Parsing metadata complete\r", end="")

                # remove axes with no variation
                single_axes = [
                    axis for axis in self.axes if len(self.axes[axis]) == 1
                ]
                for axis in single_axes:
                    del self.axes[axis]

                # If the dataset uses XY stitching, map out the row and col indices
                if "position" in self.axes and "GridPixelOverlapX" in self.summary_metadata:
                    # Make an n x 2 array with nan's where no positions actually exist
                    self.row_col_array = np.ones(
                        (len(self.axes["position"]), 2)) * np.nan
                    self.position_centers = np.ones(
                        (len(self.axes["position"]), 2)) * np.nan
                    row_cols = []
                    for c_index in c_z_t_p_tree.keys():
                        for z_index in c_z_t_p_tree[c_index].keys():
                            for t_index in c_z_t_p_tree[c_index][z_index].keys(
                            ):
                                p_indices = c_z_t_p_tree[c_index][z_index][
                                    t_index].keys()
                                for p_index in p_indices:
                                    # in case position index doesn't start at 0, pos_index_index is index
                                    # into self.axes['position']
                                    pos_index_index = list(
                                        self.axes["position"]).index(p_index)
                                    if not np.isnan(
                                            self.row_col_array[pos_index_index,
                                                               0]):
                                        # already figured this one out
                                        continue
                                    if not res_level.check_ifd(
                                            channel_index=c_index,
                                            z_index=z_index,
                                            t_index=t_index,
                                            pos_index=p_index,
                                    ):
                                        row_cols.append(
                                            np.array([
                                                np.nan, np.nan
                                            ]))  # this position is corrupted
                                        warnings.warn(
                                            "Corrupted image p: {} c: {} t: {} z: {}"
                                            .format(p_index, c_index, t_index,
                                                    z_index))
                                        row_cols.append(
                                            np.array([np.nan, np.nan]))
                                    else:
                                        md = res_level.read_metadata(
                                            channel_index=c_index,
                                            pos_index=p_index,
                                            t_index=t_index,
                                            z_index=z_index,
                                        )
                                        self.row_col_array[
                                            pos_index_index] = np.array([
                                                md["GridRowIndex"],
                                                md["GridColumnIndex"]
                                            ])
                                        self.position_centers[
                                            pos_index_index] = np.array([
                                                md["XPosition_um_Intended"],
                                                md["YPosition_um_Intended"],
                                            ])

            else:
                self.res_levels[int(np.log2(int(
                    res_dir.split("x")[1])))] = res_level

        if "GridPixelOverlapX" in self.summary_metadata:
            self._tile_width = (self.summary_metadata["Width"] -
                                self.summary_metadata["GridPixelOverlapX"])
            self._tile_height = (self.summary_metadata["Height"] -
                                 self.summary_metadata["GridPixelOverlapY"])
        else:
            self._tile_width = self.summary_metadata["Width"]
            self._tile_height = self.summary_metadata["Height"]

        print("\rDataset opened          ")
Exemple #16
0
    def __init__(self,
                 dataset_path=None,
                 full_res_only=True,
                 remote_storage_monitor=None):
        """
        Create an object providing access to an NDTiffStorage dataset, either one that is
        currently being acquired or one that already exists on disk.

        Parameters
        ----------
        dataset_path : str
            Absolute path of the top level folder of a dataset on disk. Not used when
            ``remote_storage_monitor`` is given; the location is then queried from the monitor.
        full_res_only : bool
            Only open the full resolution data, even if the dataset is multi-resolution
        remote_storage_monitor : JavaObjectShadow
            Object that allows callbacks from a remote NDTiffStorage

        Raises
        ------
        Exception
            If a dataset on disk has no "Full resolution" directory
        """
        self._tile_width = None
        self._tile_height = None
        self._lock = threading.Lock()

        if remote_storage_monitor is None:
            self._remote_storage_monitor = None
        else:
            # View of an active acquisition: the storage itself lives on the Java side
            self._remote_storage_monitor = remote_storage_monitor
            self._bridge = Bridge()
            self.summary_metadata = remote_storage_monitor.get_summary_metadata()
            if "GridPixelOverlapX" in self.summary_metadata.keys():
                self._tile_width = (self.summary_metadata["Width"] -
                                    self.summary_metadata["GridPixelOverlapX"])
                self._tile_height = (self.summary_metadata["Height"] -
                                     self.summary_metadata["GridPixelOverlapY"])

            disk_location = remote_storage_monitor.get_disk_location()
            if disk_location[-1] != os.sep:
                disk_location += os.sep
            with self._lock:
                self.res_levels = {
                    0: _ResolutionLevel(remote=True,
                                        summary_metadata=self.summary_metadata,
                                        path=disk_location + "Full resolution")
                }
            self.axes = {}
            return

        self.path = dataset_path
        res_dirs = [entry for entry in os.listdir(dataset_path)
                    if os.path.isdir(os.path.join(dataset_path, entry))]
        # maps log2(downsample factor) -> resolution level; 0 is full resolution
        with self._lock:
            self.res_levels = {}
        if "Full resolution" not in res_dirs:
            raise Exception(
                "Couldn't find full resolution directory. Is this the correct path to a dataset?"
            )
        # total .tif count over every resolution directory, handed to each _ResolutionLevel
        num_tiffs = sum(
            1 for res_dir in res_dirs
            for file_name in os.listdir(os.path.join(dataset_path, res_dir))
            if file_name.endswith(".tif"))
        count = 0

        for res_dir in res_dirs:
            is_full_res = res_dir == "Full resolution"
            if full_res_only and not is_full_res:
                continue
            res_level = _ResolutionLevel(os.path.join(dataset_path, res_dir),
                                         count, num_tiffs)
            if not is_full_res:
                # the text after the first "x" of the directory name is the downsample factor
                level = int(np.log2(int(res_dir.split("x")[1])))
                with self._lock:
                    self.res_levels[level] = res_level
                continue

            with self._lock:
                self.res_levels[0] = res_level
            # summary metadata and the index come from the full resolution level
            self.summary_metadata = res_level.summary_metadata
            if "GridPixelOverlapY" in self.summary_metadata:
                self.overlap = np.array([
                    self.summary_metadata["GridPixelOverlapY"],
                    self.summary_metadata["GridPixelOverlapX"],
                ])
            else:
                self.overlap = None

            # collect every value that each axis takes anywhere in the index
            self.axes = {}
            for axes_combo in res_level.index:
                for axis, position in axes_combo:
                    self.axes.setdefault(axis, set()).add(position)

            # figure out the mapping of channel name to position by reading image metadata
            print("\rReading channel names...", end="")
            self._read_channel_names()
            print("\rFinished reading channel names", end="")

        # Image geometry is taken from the first indexed image and assumed to hold for the
        # whole dataset (which isn't strictly necessary)
        with self._lock:
            first_index = list(self.res_levels[0].index.values())[0]
        pixel_type = first_index["pixel_type"]
        # ordered (pixel type, bytes per pixel, dtype) table; the first match wins
        for candidate_type, bytes_per_pixel, dtype in (
            (_MultipageTiffReader.EIGHT_BIT_RGB, 3, np.uint8),
            (_MultipageTiffReader.EIGHT_BIT, 1, np.uint8),
            (_MultipageTiffReader.SIXTEEN_BIT, 2, np.uint16),
        ):
            if pixel_type == candidate_type:
                self.bytes_per_pixel = bytes_per_pixel
                self.dtype = dtype
                break
        # NOTE(review): any other pixel type leaves bytes_per_pixel/dtype unset

        self.image_width = first_index["image_width"]
        self.image_height = first_index["image_height"]
        if "GridPixelOverlapX" in self.summary_metadata:
            self._tile_width = (self.image_width -
                                self.summary_metadata["GridPixelOverlapX"])
            self._tile_height = (self.image_height -
                                 self.summary_metadata["GridPixelOverlapY"])

        print("\rDataset opened                ")
Exemple #17
0
def _processor_startup_fn(pull_port, push_port, sockets_connected_evt,
                          process_fn, event_queue, debug):
    """
    Relay images through a user-supplied image processing function until told to stop.

    Connects an outgoing (push) and incoming (pull) socket, signals
    ``sockets_connected_evt``, then repeatedly receives images, passes them through
    ``process_fn`` and sends the result(s) back out. Returns after forwarding a
    ``{"special": "finished"}`` message and closing both sockets.

    Parameters
    ----------
    pull_port :
        Port the outgoing (push) socket connects to; processed images are sent here
    push_port :
        Port the incoming (pull) socket connects to; raw images are received from here
    sockets_connected_evt :
        Event-like object whose ``set()`` is called once both sockets are connected
    process_fn : Callable
        Either ``process_fn(image, metadata)`` or
        ``process_fn(image, metadata, bridge, event_queue)``. Must return None (drop the
        image), an ``(pixels, metadata)`` pair, or a list of such pairs.
    event_queue :
        Passed through to a 4-argument ``process_fn``
    debug : bool
        Passed to the Bridge and enables extra printing

    Raises
    ------
    Exception
        If ``process_fn`` does not take 2 or 4 arguments, or returns something that is
        not a ``(pixels, metadata)`` pair with the same dtype as the input image
    """
    bridge = Bridge(debug=debug)
    push_socket = bridge._connect_push(pull_port)
    pull_socket = bridge._connect_pull(push_port)
    if debug:
        print("image processing sockets connected")
    sockets_connected_evt.set()

    # The processing function's arity cannot change, so inspect it only once
    num_params = len(signature(process_fn).parameters)

    def process_and_sendoff(image_tags_tuple, original_dtype):
        """
        Validate one processed ``(pixels, metadata)`` pair and send it out.

        Parameters
        ----------
        image_tags_tuple : tuple
            ``(pixels, metadata)`` with pixels a numpy array and metadata a dict
        original_dtype : numpy.dtype
            dtype of the unprocessed image; the processed pixels must match it

        Raises
        ------
        Exception
            If the pair does not have two elements or the pixel dtype changed
        """
        if len(image_tags_tuple) != 2:
            raise Exception(
                "If image is returned, it must be of the form (pixel, metadata)"
            )

        pixels = image_tags_tuple[0]
        metadata = image_tags_tuple[1]

        # only accepts same pixel type as original
        if not (np.issubdtype(pixels.dtype, original_dtype)
                or np.issubdtype(original_dtype, pixels.dtype)):
            raise Exception(
                "Processed image pixels must have same dtype as input image pixels, "
                "but instead they were {} (processed) and {} (input)".format(
                    pixels.dtype, original_dtype))

        metadata[
            "PixelType"] = "GRAY8" if pixels.dtype.itemsize == 1 else "GRAY16"

        push_socket.send({
            "pixels": pixels.tobytes(),
            "metadata": metadata,
        })

    while True:
        message = None
        while message is None:
            message = pull_socket.receive(timeout=30)  # check for new message

        if "special" in message and message["special"] == "finished":
            # Continue propagating the finished signal, then shut down
            push_socket.send(message)
            push_socket.close()
            pull_socket.close()
            return

        metadata = message["metadata"]
        pixels = deserialize_array(message["pixels"])
        image = np.reshape(pixels, [metadata["Height"], metadata["Width"]])

        if num_params not in (2, 4):
            raise Exception(
                "Incorrect number of arguments for image processing function, must be 2 or 4"
            )
        try:
            if num_params == 2:
                processed = process_fn(image, metadata)
            else:
                processed = process_fn(image, metadata, bridge, event_queue)
        except Exception as e:
            # a failing processor drops this image but keeps the pipeline running
            warnings.warn("exception in image processor: {}".format(e))
            continue

        if processed is None:
            continue

        if isinstance(processed, list):
            for image_tags_tuple in processed:
                process_and_sendoff(image_tags_tuple, pixels.dtype)
        else:
            process_and_sendoff(processed, pixels.dtype)
Exemple #18
0
    def __init__(self,
                 dataset_path=None,
                 full_res_only=True,
                 remote_storage=None):
        """
        Open an NDTiffStorage dataset, either on disk or as a view of an active acquisition.

        Parameters
        ----------
        dataset_path : str
            Path of the top level folder of a dataset on disk (the folder containing the
            "Full resolution" directory). Unused when ``remote_storage`` is given.
        full_res_only : bool
            Only open the full resolution data, skipping downsampled resolution levels
        remote_storage :
            Java-side storage of an active acquisition. When given, only the tile size is
            read from its summary metadata and no files are opened.

        Raises
        ------
        Exception
            If ``dataset_path`` has no "Full resolution" directory
        """
        if remote_storage is not None:
            # this dataset is a view of an active acquisition. The storage exists on the java side
            self._remote_storage = remote_storage
            self._bridge = Bridge()
            smd = self._remote_storage.get_summary_metadata()
            if 'GridPixelOverlapX' in smd.keys():
                self._tile_width = smd['Width'] - smd['GridPixelOverlapX']
                self._tile_height = smd['Height'] - smd['GridPixelOverlapY']
            else:
                # no XY tiling (overlap keys absent): a tile is the whole image
                self._tile_width = smd['Width']
                self._tile_height = smd['Height']
            return
        self._remote_storage = None

        self.path = dataset_path
        res_dirs = [
            dI for dI in os.listdir(dataset_path)
            if os.path.isdir(os.path.join(dataset_path, dI))
        ]
        # map from log2(downsample factor) to resolution level (0 is full resolution)
        self.res_levels = {}
        if 'Full resolution' not in res_dirs:
            raise Exception(
                'Couldn\'t find full resolution directory. Is this the correct path to a dataset?'
            )
        # total number of .tif files across all resolution directories
        num_tiffs = sum(
            1 for res_dir in res_dirs
            for file_name in os.listdir(os.path.join(dataset_path, res_dir))
            if file_name.endswith('.tif'))
        count = 0
        for res_dir in res_dirs:
            if full_res_only and res_dir != 'Full resolution':
                continue
            res_dir_path = os.path.join(dataset_path, res_dir)
            res_level = _ResolutionLevel(res_dir_path, count, num_tiffs)
            if res_dir != 'Full resolution':
                # the text after the first 'x' of the directory name is the downsample factor
                self.res_levels[int(np.log2(int(
                    res_dir.split('x')[1])))] = res_level
                continue

            #TODO: might want to move this within the resolution level class to facilitate loading pyramids
            self.res_levels[0] = res_level
            # get summary metadata and index tree from full resolution image
            self.summary_metadata = res_level.reader_list[0].summary_md
            # channel name -> channel index, read from image metadata
            self._channel_names = {}
            # frozenset of (axis, value) pairs of all non z/time/position axes -> super channel
            self._extra_axes_to_storage_channel = {}

            # store some fields explicitly for easy access
            self.dtype = np.uint16 if self.summary_metadata[
                'PixelType'] == 'GRAY16' else np.uint8
            self.pixel_size_xy_um = self.summary_metadata['PixelSize_um']
            self.pixel_size_z_um = self.summary_metadata[
                'z-step_um'] if 'z-step_um' in self.summary_metadata else None
            self.image_width = res_level.reader_list[0].width
            self.image_height = res_level.reader_list[0].height
            self.overlap = np.array([
                self.summary_metadata['GridPixelOverlapY'],
                self.summary_metadata['GridPixelOverlapX']
            ]) if 'GridPixelOverlapY' in self.summary_metadata else None
            c_z_t_p_tree = res_level.reader_tree
            # the c here refers to super channels, encompassing all non-tzp axes in addition to channels
            # map of axis names to values where data exists
            self.axes = {
                self._Z_AXIS: set(),
                self._TIME_AXIS: set(),
                self._POSITION_AXIS: set(),
                self._CHANNEL_AXIS: set()
            }
            # Super channels whose metadata was already parsed. Each super channel stands for
            # one combination of non z/time/position axes (see _extra_axes_to_storage_channel),
            # so one image per super channel is enough. (Checking c against the *channel*
            # axis values, as before, mixed up super channel and channel indices.)
            parsed_super_channels = set()
            for c in c_z_t_p_tree.keys():
                for z in c_z_t_p_tree[c]:
                    self.axes[self._Z_AXIS].add(z)
                    for t in c_z_t_p_tree[c][z]:
                        self.axes[self._TIME_AXIS].add(t)
                        for p in c_z_t_p_tree[c][z][t]:
                            self.axes[self._POSITION_AXIS].add(p)
                            if c in parsed_super_channels:
                                continue
                            parsed_super_channels.add(c)
                            metadata = self.res_levels[0].read_metadata(
                                channel_index=c,
                                z_index=z,
                                t_index=t,
                                pos_index=p)
                            current_axes = metadata['Axes']
                            non_zpt_axes = {}
                            for axis in current_axes:
                                if axis in (self._Z_AXIS, self._TIME_AXIS,
                                            self._POSITION_AXIS):
                                    continue
                                self.axes.setdefault(axis, set()).add(
                                    current_axes[axis])
                                non_zpt_axes[axis] = current_axes[axis]

                            self._channel_names[
                                metadata['Channel']] = non_zpt_axes[
                                    self._CHANNEL_AXIS]
                            self._extra_axes_to_storage_channel[
                                frozenset(non_zpt_axes.items())] = c

            # remove axes with no variation
            single_axes = [
                axis for axis in self.axes if len(self.axes[axis]) == 1
            ]
            for axis in single_axes:
                del self.axes[axis]

            # If the dataset uses XY stitching, record the grid (row, col) of every position
            if 'position' in self.axes and 'GridPixelOverlapX' in self.summary_metadata:
                row_cols = {}  # position index -> np.array([row, col])
                for c_index in c_z_t_p_tree.keys():
                    for z_index in c_z_t_p_tree[c_index].keys():
                        for t_index in c_z_t_p_tree[c_index][z_index].keys():
                            p_indices = c_z_t_p_tree[c_index][z_index][
                                t_index].keys()
                            for p_index in p_indices:
                                if p_index in row_cols:
                                    continue  # already figured this one out
                                if not res_level.check_ifd(
                                        channel_index=c_index,
                                        z_index=z_index,
                                        t_index=t_index,
                                        pos_index=p_index):
                                    # this image is corrupted; the position may still be
                                    # readable from another channel/z/time
                                    warnings.warn(
                                        'Corrupted image p: {} c: {} t: {} z: {}'
                                        .format(p_index, c_index, t_index,
                                                z_index))
                                    continue
                                md = res_level.read_metadata(
                                    channel_index=c_index,
                                    pos_index=p_index,
                                    t_index=t_index,
                                    z_index=z_index)
                                row_cols[p_index] = np.array([
                                    md['GridRowIndex'],
                                    md['GridColumnIndex']
                                ])
                # n x 2 array indexed by position index, with nan rows where a position does
                # not exist or none of its images could be read
                self.row_col_array = np.stack([
                    row_cols.get(p_index, np.array([np.nan, np.nan]))
                    for p_index in range(max(self.axes['position']) + 1)
                ])
        print('\rDataset opened')
Exemple #19
0
class Acquisition(object):
    """Python front end for an acquisition executed by the Java acquisition engine.

    The constructor creates the remote (Java-side) acquisition through a ``Bridge``, wires up
    the optional image processor and acquisition hooks, starts the remote acquisition and, for
    non-Magellan acquisitions, starts a background thread that forwards events submitted with
    :meth:`acquire` to the Java side. When used as a context manager, leaving the ``with``
    block signals that no more events will be submitted and then waits for completion.
    """

    def __init__(
        self,
        directory=None,
        name=None,
        image_process_fn=None,
        event_generation_hook_fn=None,
        pre_hardware_hook_fn=None,
        post_hardware_hook_fn=None,
        post_camera_hook_fn=None,
        show_display=True,
        tile_overlap=None,
        max_multi_res_index=None,
        magellan_acq_index=None,
        magellan_explore=False,
        process=False,
        debug=False,
    ):
        """
        Parameters
        ----------
        directory : str
            saving directory for this acquisition. Required unless an image process function will be
            implemented that diverts images from saving
        name : str
            Saving name for the acquisition. Required unless an image process function will be
            implemented that diverts images from saving
        image_process_fn : Callable
            image processing function that will be called on each image that gets acquired.
            Can either take two arguments (image, metadata) where image is a numpy array and metadata is a dict
            containing the corresponding image metadata. Or a 4 argument version is accepted, which accepts (image,
            metadata, bridge, queue), where bridge is an instance of the pycromanager.acquire.Bridge
            object for the purposes of interacting with arbitrary code on the Java side (such as the micro-manager
            core), and queue is a Queue object that holds upcoming acquisition events. The accepted return values
            are defined by ``_processor_startup_fn`` (presumably an (image, metadata) tuple, a list of such
            tuples, or None to discard the image -- TODO confirm)
        event_generation_hook_fn : Callable
            hook function that will be run as soon as acquisition events are generated (before hardware sequencing
            optimization in the acquisition engine). This is useful if one wants to modify acquisition events that
            they didn't generate (e.g. those generated by a GUI application). Accepts either one argument (the
            current acquisition event) or three arguments (current event, bridge, event Queue)
        pre_hardware_hook_fn : Callable
            hook function that will be run just before the hardware is updated before acquiring
            a new image. In the case of hardware sequencing, it will be run just before a sequence of instructions are
            dispatched to the hardware. Accepts either one argument (the current acquisition event) or three arguments
            (current event, bridge, event Queue)
        post_hardware_hook_fn : Callable
            hook function that will be run just after the hardware is updated before acquiring
            a new image. In the case of hardware sequencing, it will be run just after a sequence of instructions are
            dispatched to the hardware, but before the camera sequence has been started. Accepts either one argument
            (the current acquisition event) or three arguments (current event, bridge, event Queue)
        post_camera_hook_fn : Callable
            hook function that will be run just after the camera has been triggered to snapImage or
            startSequence. A common use case for this hook is when one want to send TTL triggers to the camera from an
            external timing device that synchronizes with other hardware. Accepts either one argument (the current
            acquisition event) or three arguments (current event, bridge, event Queue)
        tile_overlap : int or tuple of int
            If given, XY tiles will be laid out in a grid and multi-resolution saving will be
            activated. Argument can be a two element tuple (or list) describing the pixel overlaps between adjacent
            tiles. i.e. (pixel_overlap_x, pixel_overlap_y), or an integer to use the same overlap for both.
            For these features to work, the current hardware configuration must have a valid affine transform
            between camera coordinates and XY stage coordinates
        max_multi_res_index : int
            Maximum index to downsample to in multi-res pyramid mode (which is only active if a value for
            "tile_overlap" is passed in, or if running a Micro-Magellan acquisition). 0 is no downsampling,
            1 is downsampled up to 2x, 2 is downsampled up to 4x, etc. If not provided, it will be dynamically
            calculated and updated from data
        show_display : bool
            show the image viewer window (only honored when both ``directory`` and ``name`` are given)
        magellan_acq_index : int
            run this acquisition using the settings specified at this position in the main
            GUI of micro-magellan (micro-manager plugin). This index starts at 0
        magellan_explore : bool
            Run a Micro-magellan explore acquisition
        process : bool
            Use multiprocessing instead of multithreading for acquisition hooks and image
            processors. This can be used to speed up CPU-bounded processing by eliminating bottlenecks
            caused by Python's Global Interpreter Lock, but also creates complications on Windows-based
            systems
        debug : bool
            whether to print debug messages
        """
        self.bridge = Bridge(debug=debug)
        self._debug = debug
        self._dataset = None
        # Stays None unless the remote data sink reports a disk location (see below)
        self.disk_location = None

        if directory is not None:
            # Expand ~ in the path
            directory = os.path.expanduser(directory)
            # Resolve relative paths against this process's current working directory
            # before the path is handed to the Java side
            directory = os.path.abspath(directory)

        if magellan_acq_index is not None:
            magellan_api = self.bridge.get_magellan()
            self._remote_acq = magellan_api.create_acquisition(magellan_acq_index)
            # No Python-side event queue for Magellan acquisitions, so acquire() cannot be used
            self._event_queue = None
        elif magellan_explore:
            magellan_api = self.bridge.get_magellan()
            self._remote_acq = magellan_api.create_explore_acquisition()
            self._event_queue = None
        else:
            # Thread/process safe queue so events can be submitted from multiple threads or processes
            self._event_queue = multiprocessing.Queue() if process else queue.Queue()
            core = self.bridge.get_core()
            acq_factory = self.bridge.construct_java_object(
                "org.micromanager.remote.RemoteAcquisitionFactory", args=[core]
            )

            # The viewer is only shown when the data is being saved (directory and name given)
            show_viewer = show_display and (directory is not None and name is not None)
            if tile_overlap is None:
                # Placeholder values; they are not used when tiling is disabled
                x_overlap = y_overlap = 0
            elif isinstance(tile_overlap, (tuple, list)):
                x_overlap, y_overlap = tile_overlap
            else:
                x_overlap = y_overlap = tile_overlap

            self._remote_acq = acq_factory.create_acquisition(
                directory,
                name,
                show_viewer,
                tile_overlap is not None,
                x_overlap,
                y_overlap,
                # -1 tells the Java side that no maximum was given
                max_multi_res_index if max_multi_res_index is not None else -1,
            )
        storage = self._remote_acq.get_data_sink()
        if storage is not None:
            self.disk_location = storage.get_disk_location()

        if image_process_fn is not None:
            processor = self.bridge.construct_java_object(
                "org.micromanager.remote.RemoteImageProcessor"
            )
            self._remote_acq.add_image_processor(processor)
            self._start_processor(
                processor, image_process_fn, self._event_queue, process=process
            )

        # Hooks are registered in this fixed order; the second element names the
        # hook-type constant exposed by the remote acquisition object
        for hook_fn, hook_type_name in (
            (event_generation_hook_fn, "EVENT_GENERATION_HOOK"),
            (pre_hardware_hook_fn, "BEFORE_HARDWARE_HOOK"),
            (post_hardware_hook_fn, "AFTER_HARDWARE_HOOK"),
            (post_camera_hook_fn, "AFTER_CAMERA_HOOK"),
        ):
            if hook_fn is not None:
                self._add_remote_hook(hook_fn, hook_type_name, process)

        self._remote_acq.start()

        if magellan_acq_index is None and not magellan_explore:
            self.event_port = self._remote_acq.get_event_port()

            # Forwards events put on self._event_queue to the Java side until a None
            # (poison value) is received
            self.event_process = threading.Thread(
                target=_event_sending_fn,
                args=(self.event_port, self._event_queue, self._debug),
                name="Event sending",
            )
            self.event_process.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Signal that no more events will be submitted, then wait for the acquisition to finish."""
        if self._event_queue is not None:  # magellan acquisitions don't have this
            # this should shut down storage and viewer as appropriate
            self._event_queue.put(None)
        # now wait on it to finish
        self.await_completion()

    def get_disk_location(self):
        """
        Return the path where the dataset is on disk
        """
        # NOTE(review): __init__ reads the location via get_data_sink(); this uses
        # get_storage(). Confirm both remote accessors refer to the same object.
        return self._remote_acq.get_storage().get_disk_location()

    def get_dataset(self):
        """Return a ``Dataset`` view of the remote storage, created on first call and cached."""
        if self._dataset is None:
            self._dataset = Dataset(remote_storage=self._remote_acq.get_storage())
        return self._dataset

    def await_completion(self):
        """Wait for acquisition to finish and resources to be cleaned up"""
        # Poll the remote acquisition every 100 ms
        while not self._remote_acq.is_finished():
            time.sleep(0.1)

    def acquire(self, events, keep_shutter_open=False):
        """Submit an event or a list of events for acquisition.

        Optimizations (i.e. taking advantage of hardware synchronization, where available) will take
        place across this list of events, but not over multiple calls of this method. A single event
        is a python dictionary with a specific structure.

        Parameters
        ----------
        events : dict or list of dict
            acquisition event(s) to submit
        keep_shutter_open : bool
            if True, every submitted event is marked ``keep_shutter_open`` and one extra event
            ``{"keep_shutter_open": False}`` is appended, which returns to autoshutter without
            acquiring an image. (Default value = False)

        Notes
        -----
        With ``keep_shutter_open=True`` the caller's list and event dicts are not modified;
        shallow copies are queued instead. Otherwise ``events`` is queued as-is.
        """
        if keep_shutter_open and isinstance(events, list):
            # Copy so the caller's objects are not mutated, and are not shared with the
            # event-sending thread that consumes the queue
            events = [dict(event, keep_shutter_open=True) for event in events]
            # return to autoshutter, don't acquire an image
            events.append({"keep_shutter_open": False})
        elif keep_shutter_open and isinstance(events, dict):
            events = [
                dict(events, keep_shutter_open=True),
                {"keep_shutter_open": False},  # return to autoshutter, don't acquire an image
            ]
        self._event_queue.put(events)

    def _add_remote_hook(self, hook_fn, hook_type_name, process):
        """Create a Java ``RemoteAcqHook``, start its Python-side handler and register it.

        Parameters
        ----------
        hook_fn : Callable
            user hook function executed by the handler
        hook_type_name : str
            name of the hook-type constant on the remote acquisition, e.g. "BEFORE_HARDWARE_HOOK"
        process : bool
            run the handler in a separate process instead of a thread
        """
        hook = self.bridge.construct_java_object(
            "org.micromanager.remote.RemoteAcqHook", args=[self._remote_acq]
        )
        # _start_hook blocks until the handler's sockets are connected, so the hook is
        # only registered once it can be serviced
        self._start_hook(hook, hook_fn, self._event_queue, process=process)
        self._remote_acq.add_hook(hook, getattr(self._remote_acq, hook_type_name))

    def _start_hook(self, remote_hook, remote_hook_fn, event_queue, process):
        """Start the thread/process that services ``remote_hook`` and wait until it is connected.

        Parameters
        ----------
        remote_hook :
            Java-side hook object providing the push/pull ports
        remote_hook_fn : Callable
            user hook function run by ``_acq_hook_startup_fn``
        event_queue :
            event queue passed to 3-argument hook functions
        process : bool
            use a ``multiprocessing.Process`` instead of a ``threading.Thread``
        """
        hook_connected_evt = multiprocessing.Event() if process else threading.Event()

        pull_port = remote_hook.get_pull_port()
        push_port = remote_hook.get_push_port()

        worker_cls = multiprocessing.Process if process else threading.Thread
        hook_thread = worker_cls(
            target=_acq_hook_startup_fn,
            name="AcquisitionHook",
            args=(
                pull_port,
                push_port,
                hook_connected_evt,
                event_queue,
                remote_hook_fn,
                self._debug,
            ),
        )
        hook_thread.start()

        hook_connected_evt.wait()  # wait for push/pull sockets to connect

    def _start_processor(self, processor, process_fn, event_queue, process):
        """Start the thread/process that runs the image processor and connect it to Java.

        Parameters
        ----------
        processor :
            Java-side ``RemoteImageProcessor``
        process_fn : Callable
            user image processing function
        event_queue :
            event queue passed to the 4-argument processing function
        process : bool
            use a ``multiprocessing.Process`` instead of a ``threading.Thread``
        """
        # this must start first
        processor.start_pull()

        sockets_connected_evt = multiprocessing.Event() if process else threading.Event()

        pull_port = processor.get_pull_port()
        push_port = processor.get_push_port()

        worker_cls = multiprocessing.Process if process else threading.Thread
        self.processor_thread = worker_cls(
            target=_processor_startup_fn,
            args=(
                pull_port,
                push_port,
                sockets_connected_evt,
                process_fn,
                event_queue,
                self._debug,
            ),
            name="ImageProcessor",
        )
        self.processor_thread.start()

        sockets_connected_evt.wait()  # wait for push/pull sockets to connect
        processor.start_push()
Exemple #20
0
    def __init__(self, dataset_path=None, full_res_only=True, remote_storage=None):
        """Open a dataset from disk, or wrap the storage of an acquisition in progress.

        Parameters
        ----------
        dataset_path : str
            Directory of the dataset; it must contain a "Full resolution" subdirectory.
            Not used when ``remote_storage`` is given.
        full_res_only : bool
            If True (default), only the "Full resolution" level is opened; otherwise the
            downsampled levels next to it are opened as well.
        remote_storage :
            Java-side storage of an in-progress acquisition. When given, only the tile size is
            read from its summary metadata and nothing is read from disk.

        Raises
        ------
        Exception
            If ``dataset_path`` has no "Full resolution" subdirectory.
        """
        self._tile_width = None
        self._tile_height = None
        if remote_storage is not None:
            # this dataset is a view of an active acquisition. The storage exists on the java side
            self._remote_storage = remote_storage
            self._bridge = Bridge()
            smd = self._remote_storage.get_summary_metadata()
            if "GridPixelOverlapX" in smd.keys():
                self._tile_width = smd["Width"] - smd["GridPixelOverlapX"]
                self._tile_height = smd["Height"] - smd["GridPixelOverlapY"]
            return
        self._remote_storage = None

        self.path = dataset_path
        res_dirs = [
            dI for dI in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, dI))
        ]
        # map from log2(downsample factor) to resolution level (0 = full resolution)
        self.res_levels = {}
        if "Full resolution" not in res_dirs:
            raise Exception(
                "Couldn't find full resolution directory. Is this the correct path to a dataset?"
            )
        # Total number of .tif files over all resolution directories; handed to every level
        num_tiffs = sum(
            1
            for res_dir in res_dirs
            for file in os.listdir(os.path.join(dataset_path, res_dir))
            if file.endswith(".tif")
        )
        count = 0
        for res_dir in res_dirs:
            if full_res_only and res_dir != "Full resolution":
                continue
            res_dir_path = os.path.join(dataset_path, res_dir)
            res_level = _ResolutionLevel(res_dir_path, count, num_tiffs)
            if res_dir == "Full resolution":
                # TODO: might want to move this within the resolution level class to facilitate loading pyramids
                self.res_levels[0] = res_level
                # get summary metadata and index tree from full resolution image
                first_reader = res_level.reader_list[0]
                self.summary_metadata = first_reader.summary_md
                self.rgb = first_reader.rgb
                # channel name -> channel axis value, read from image metadata
                self._channel_names = {}
                # frozenset of (axis, value) items for all non-z/t/p axes -> storage channel index
                self._extra_axes_to_storage_channel = {}

                # store some fields explicitly for easy access
                self.dtype = (
                    np.uint16 if self.summary_metadata["PixelType"] == "GRAY16" else np.uint8
                )
                self.pixel_size_xy_um = self.summary_metadata["PixelSize_um"]
                self.pixel_size_z_um = (
                    self.summary_metadata["z-step_um"]
                    if "z-step_um" in self.summary_metadata
                    else None
                )
                self.image_width = first_reader.width
                self.image_height = first_reader.height
                self.overlap = (
                    np.array(
                        [
                            self.summary_metadata["GridPixelOverlapY"],
                            self.summary_metadata["GridPixelOverlapX"],
                        ]
                    )
                    if "GridPixelOverlapY" in self.summary_metadata
                    else None
                )
                c_z_t_p_tree = res_level.reader_tree
                # the c here refers to storage ("super") channels, encompassing all non-tzp axes
                # in addition to channels. Map of axis names to values where data exists:
                self.axes = {
                    self._Z_AXIS: set(),
                    self._TIME_AXIS: set(),
                    self._POSITION_AXIS: set(),
                    self._CHANNEL_AXIS: set(),
                }
                zpt_axes = (self._Z_AXIS, self._TIME_AXIS, self._POSITION_AXIS)
                # Storage channels whose metadata has already been read. Tracked explicitly:
                # storage channel indices are not channel-axis values, so they cannot be
                # looked up in self.axes[self._CHANNEL_AXIS].
                seen_channels = set()
                for c, z_tree in c_z_t_p_tree.items():
                    for z, t_tree in z_tree.items():
                        self.axes[self._Z_AXIS].add(z)
                        for t, p_tree in t_tree.items():
                            self.axes[self._TIME_AXIS].add(t)
                            for p in p_tree:
                                self.axes[self._POSITION_AXIS].add(p)
                                if c in seen_channels:
                                    continue
                                seen_channels.add(c)
                                metadata = self.res_levels[0].read_metadata(
                                    channel_index=c, z_index=z, t_index=t, pos_index=p
                                )
                                non_zpt_axes = {}
                                for axis, value in metadata["Axes"].items():
                                    if axis in zpt_axes:
                                        continue
                                    self.axes.setdefault(axis, set()).add(value)
                                    non_zpt_axes[axis] = value

                                self._channel_names[metadata["Channel"]] = non_zpt_axes[
                                    self._CHANNEL_AXIS
                                ]
                                self._extra_axes_to_storage_channel[
                                    frozenset(non_zpt_axes.items())
                                ] = c

                # remove axes with no variation
                single_axes = [axis for axis in self.axes if len(self.axes[axis]) == 1]
                for axis in single_axes:
                    del self.axes[axis]

                if (
                    self._POSITION_AXIS in self.axes
                    and "GridPixelOverlapX" in self.summary_metadata
                ):
                    # Position indices need not start at 0: map each one to its row in the
                    # arrays below (row order = iteration order of self.axes["position"])
                    pos_rows = {p: i for i, p in enumerate(self.axes[self._POSITION_AXIS])}
                    # n x 2 arrays with NaN where no position data was found
                    self.row_col_array = np.full((len(pos_rows), 2), np.nan)
                    self.position_centers = np.full((len(pos_rows), 2), np.nan)
                    for c_index, z_tree in c_z_t_p_tree.items():
                        for z_index, t_tree in z_tree.items():
                            for t_index, p_tree in t_tree.items():
                                for p_index in p_tree.keys():
                                    pos_index_index = pos_rows[p_index]
                                    if not np.isnan(self.row_col_array[pos_index_index, 0]):
                                        # already figured this one out
                                        continue
                                    if not res_level.check_ifd(
                                        channel_index=c_index,
                                        z_index=z_index,
                                        t_index=t_index,
                                        pos_index=p_index,
                                    ):
                                        # this position is corrupted; leave its rows as NaN
                                        warnings.warn(
                                            "Corrupted image p: {} c: {} t: {} z: {}".format(
                                                p_index, c_index, t_index, z_index
                                            )
                                        )
                                        continue
                                    md = res_level.read_metadata(
                                        channel_index=c_index,
                                        pos_index=p_index,
                                        t_index=t_index,
                                        z_index=z_index,
                                    )
                                    self.row_col_array[pos_index_index] = np.array(
                                        [md["GridRowIndex"], md["GridColumnIndex"]]
                                    )
                                    self.position_centers[pos_index_index] = np.array(
                                        [
                                            md["XPosition_um_Intended"],
                                            md["YPosition_um_Intended"],
                                        ]
                                    )

            else:
                # downsampled level directory: its name ends in "x<factor>"
                self.res_levels[int(np.log2(int(res_dir.split("x")[1])))] = res_level
        print("\rDataset opened")
Exemple #21
0
class Dataset:
    """Class that opens a single NDTiffStorage dataset"""

    _POSITION_AXIS = "position"
    _ROW_AXIS = "roq"
    _COLUMN_AXIS = "column"
    _Z_AXIS = "z"
    _TIME_AXIS = "time"
    _CHANNEL_AXIS = "channel"

    def __new__(cls, dataset_path=None, full_res_only=True, remote_storage=None):
        if dataset_path is None:
            return super(Dataset, cls).__new__(Dataset)
        # Search for Full resolution dir, check for index
        res_dirs = [
            dI for dI in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, dI))
        ]
        if "Full resolution" not in res_dirs:
            raise Exception(
                "Couldn't find full resolution directory. Is this the correct path to a dataset?"
            )
        fullres_path = (
            dataset_path + ("" if dataset_path[-1] == os.sep else os.sep) + "Full resolution"
        )
        if "NDTiff.index" in os.listdir(fullres_path):
            return super(Dataset, cls).__new__(Dataset)
        else:
            obj = Legacy_NDTiff_Dataset.__new__(Legacy_NDTiff_Dataset)
            obj.__init__(dataset_path, full_res_only, remote_storage)
            return obj

    def __init__(self, dataset_path=None, full_res_only=True, remote_storage=None):
        """Open an NDTiff dataset from disk, or wrap the storage of an acquisition in progress.

        Parameters
        ----------
        dataset_path : str
            Directory of the dataset; it must contain a "Full resolution" subdirectory.
            Not used when ``remote_storage`` is given.
        full_res_only : bool
            If True (default), only the "Full resolution" level is opened; otherwise the
            downsampled levels next to it are opened as well.
        remote_storage :
            Java-side storage of an in-progress acquisition. When given, only the tile size is
            read from its summary metadata and nothing is read from disk.

        Raises
        ------
        Exception
            If ``dataset_path`` has no "Full resolution" subdirectory.
        IndexError
            If the full resolution level's index contains no images.
        """
        self._tile_width = None
        self._tile_height = None
        if remote_storage is not None:
            # this dataset is a view of an active acquisition. The storage exists on the java side
            self._remote_storage = remote_storage
            self._bridge = Bridge()
            smd = self._remote_storage.get_summary_metadata()
            if "GridPixelOverlapX" in smd.keys():
                self._tile_width = smd["Width"] - smd["GridPixelOverlapX"]
                self._tile_height = smd["Height"] - smd["GridPixelOverlapY"]
            return
        self._remote_storage = None

        self.path = dataset_path
        res_dirs = [
            dI for dI in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, dI))
        ]
        # map from log2(downsample factor) to resolution level (0 = full resolution)
        self.res_levels = {}
        if "Full resolution" not in res_dirs:
            raise Exception(
                "Couldn't find full resolution directory. Is this the correct path to a dataset?"
            )
        # Total number of .tif files over all resolution directories; handed to every level
        num_tiffs = sum(
            1
            for res_dir in res_dirs
            for file in os.listdir(os.path.join(dataset_path, res_dir))
            if file.endswith(".tif")
        )
        count = 0
        for res_dir in res_dirs:
            if full_res_only and res_dir != "Full resolution":
                continue
            res_dir_path = os.path.join(dataset_path, res_dir)
            res_level = _ResolutionLevel(res_dir_path, count, num_tiffs)
            if res_dir == "Full resolution":
                self.res_levels[0] = res_level
                # get summary metadata and index tree from full resolution image
                self.summary_metadata = res_level.summary_metadata

                self.overlap = (
                    np.array(
                        [
                            self.summary_metadata["GridPixelOverlapY"],
                            self.summary_metadata["GridPixelOverlapX"],
                        ]
                    )
                    if "GridPixelOverlapY" in self.summary_metadata
                    else None
                )

                # axis name -> set of positions at which data exists; every index key is an
                # iterable of (axis, position) pairs
                self.axes = {}
                for axes_combo in res_level.index.keys():
                    for axis, position in axes_combo:
                        self.axes.setdefault(axis, set()).add(position)

                # figure out the mapping of channel name to position by reading image metadata
                print("\rReading channel names...", end="")
                if self._CHANNEL_AXIS in self.axes.keys():
                    self._channel_names = {}
                    num_channels = len(self.axes[self._CHANNEL_AXIS])
                    for key in res_level.index.keys():
                        axes = dict(key)
                        if (
                            self._CHANNEL_AXIS in axes
                            and axes[self._CHANNEL_AXIS] not in self._channel_names.values()
                        ):
                            channel_name = res_level.read_metadata(axes)["Channel"]
                            self._channel_names[channel_name] = axes[self._CHANNEL_AXIS]
                        # stop as soon as every channel has a name
                        if len(self._channel_names) == num_channels:
                            break
                print("\rFinished reading channel names", end="")

                # remove axes with no variation
                single_axes = [axis for axis in self.axes if len(self.axes[axis]) == 1]
                for axis in single_axes:
                    del self.axes[axis]

                # TODO: if the dataset uses XY stitching ("TiledImageStorage" in the summary
                # metadata), map out the row and col indices. Not implemented yet.

            else:
                # downsampled level directory: its name ends in "x<factor>"
                self.res_levels[int(np.log2(int(res_dir.split("x")[1])))] = res_level

        # get information about image width and height, assuming that they are consistent for whole dataset
        # (which isn't strictly necessary). Only the first index entry is needed, so avoid
        # materializing the whole index.
        try:
            first_index = next(iter(self.res_levels[0].index.values()))
        except StopIteration:
            raise IndexError("Full resolution index contains no images") from None
        if first_index["pixel_type"] == _MultipageTiffReader.EIGHT_BIT_RGB:
            self.bytes_per_pixel = 3
            self.dtype = np.uint8
        elif first_index["pixel_type"] == _MultipageTiffReader.EIGHT_BIT:
            self.bytes_per_pixel = 1
            self.dtype = np.uint8
        elif first_index["pixel_type"] == _MultipageTiffReader.SIXTEEN_BIT:
            self.bytes_per_pixel = 2
            self.dtype = np.uint16

        self.image_width = first_index["image_width"]
        self.image_height = first_index["image_height"]
        if "GridPixelOverlapX" in self.summary_metadata:
            self._tile_width = self.image_width - self.summary_metadata["GridPixelOverlapX"]
            self._tile_height = self.image_height - self.summary_metadata["GridPixelOverlapY"]

        print("\rDataset opened                ")

    def as_array(self, stitched=False, verbose=True):
        """
        Read all image data as one big Dask array with last two axes as y, x and preceding axes depending on data.
        The dask array is made up of memory-mapped numpy arrays, so the dataset does not need to be able to fit into RAM.
        If the data doesn't fully fill out the array (e.g. not every z-slice collected at every time point), zeros will
        be added automatically.

        To convert data into a numpy array, call np.asarray() on the returned result. However, doing so will bring the
        data into RAM, so it may be better to do this on only a slice of the array at a time.

        Parameters
        ----------
        stitched : bool
            If true and tiles were acquired in a grid, lay out adjacent tiles next to one another (Default value = False)
        verbose : bool
            If True print updates on progress loading the image

        Returns
        -------
        dataset : dask array

        Raises
        ------
        Exception
            If this dataset is an in-progress (remote) acquisition.
        """
        if self._remote_storage is not None:
            raise Exception("Method not yet implemented for in progress acquisitions")

        # Placeholder used wherever an image is missing. When stitching, real tiles are cropped
        # to the tile size, so the placeholder must use the tile size too.
        w = self.image_width if not stitched else self._tile_width
        h = self.image_height if not stitched else self._tile_height
        self._empty_tile = (
            np.zeros((h, w), self.dtype)
            if self.bytes_per_pixel != 3
            else np.zeros((h, w, 3), self.dtype)
        )
        self._count = 1
        total = np.prod([len(v) for v in self.axes.values()])

        def recurse_axes(loop_axes, point_axes):
            # Leaf: every axis has a value, so return a single 2D (or 2D x RGB) image.
            if len(loop_axes.values()) == 0:
                if verbose:
                    print("\rAdding data chunk {} of {}".format(self._count, total), end="")
                self._count += 1
                if None not in point_axes.values() and self.has_image(**point_axes):
                    img = self.read_image(**point_axes, memmapped=True)
                    if stitched:
                        # Trim half the overlap from each edge. Explicit end indices keep a zero
                        # overlap along one axis from turning into an empty ``0:-0`` slice.
                        y_trim, x_trim = self.half_overlap
                        img = img[
                            y_trim : img.shape[0] - y_trim,
                            x_trim : img.shape[1] - x_trim,
                        ]
                    return img
                return self._empty_tile

            # do position first because it makes stitching faster
            axis = (
                "position"
                if "position" in loop_axes.keys() and stitched
                else list(loop_axes.keys())[0]
            )
            remaining_axes = loop_axes.copy()
            del remaining_axes[axis]

            if axis == "position" and stitched:
                # Stitch tiles acquired in a grid
                self.half_overlap = (self.overlap[0] // 2, self.overlap[1] // 2)

                # get spatial layout of position indices
                zero_min_row_col = self.row_col_array - np.nanmin(self.row_col_array, axis=0)
                row_col_mat = np.nan * np.ones(
                    [
                        int(np.nanmax(zero_min_row_col[:, 0])) + 1,
                        int(np.nanmax(zero_min_row_col[:, 1])) + 1,
                    ]
                )
                positions_indices = np.array(list(loop_axes["position"]))
                rows = zero_min_row_col[positions_indices][:, 0]
                cols = zero_min_row_col[positions_indices][:, 1]
                # mask in case some positions were corrupted
                mask = np.logical_not(np.isnan(rows))
                # builtin int: the np.int alias was removed in NumPy 1.24
                row_col_mat[rows[mask].astype(int), cols[mask].astype(int)] = positions_indices[mask]

                blocks = []
                for row in row_col_mat:
                    blocks.append([])
                    for p_index in row:
                        if verbose:
                            print(
                                "\rAdding data chunk {} of {}".format(self._count, total),
                                end="",
                            )
                        valed_axes = point_axes.copy()
                        # NaN marks a grid cell with no position; None makes the leaf emit the empty tile
                        valed_axes[axis] = int(p_index) if not np.isnan(p_index) else None
                        blocks[-1].append(da.stack(recurse_axes(remaining_axes, valed_axes)))

                if self.rgb:
                    # Last axis is color, so width is ndim-2 and height is ndim-3
                    stitched_array = np.concatenate(
                        [
                            np.concatenate(row, axis=len(blocks[0][0].shape) - 2)
                            for row in blocks
                        ],
                        axis=len(blocks[0][0].shape) - 3,
                    )
                else:
                    stitched_array = da.block(blocks)
                return stitched_array

            blocks = []
            for val in loop_axes[axis]:
                valed_axes = point_axes.copy()
                valed_axes[axis] = val
                blocks.append(recurse_axes(remaining_axes, valed_axes))
            return blocks

        blocks = recurse_axes(self.axes, {})

        if verbose:
            print(
                "\rStacking tiles...         "
            )  # extra space otherwise there is no space after the "Adding data chunk {} {}"
        array = da.stack(blocks, allow_unknown_chunksizes=False)
        if verbose:
            print("\rDask array opened")
        return array

    def has_image(
        self,
        channel=0,
        z=None,
        time=None,
        position=None,
        channel_name=None,
        resolution_level=0,
        row=None,
        col=None,
        **kwargs
    ):
        """Check if this image is present in the dataset

        Parameters
        ----------
        channel : int
            index of the channel, if applicable (Default value = 0)
        z : int
            index of z slice, if applicable (Default value = None)
        time : int
            index of the time point, if applicable (Default value = None)
        position : int
            index of the XY position, if applicable (Default value = None)
        channel_name : str
            Name of the channel. Overrides channel index if supplied (Default value = None)
        row : int
            index of tile row for XY tiled datasets (Default value = None)
        col : int
            index of tile col for XY tiled datasets (Default value = None)
        resolution_level :
            0 is full resolution, otherwise represents downampling of pixels
            at 2 ** (resolution_level) (Default value = 0)
        **kwargs
            Arbitrary keyword arguments

        Returns
        -------
        bool :
            indicating whether the dataset has an image matching the specifications
        """
        if self._remote_storage is not None:
            by_row_col = row is not None and col is not None
            # Use the same consolidated axes as read_image so that channel/z/time/position are
            # honoured. When row/col go to has_tile_by_row_col as separate arguments, keep them
            # out of the axes map.
            axes = self._consolidate_axes(
                channel,
                channel_name,
                z,
                position,
                time,
                None if by_row_col else row,
                None if by_row_col else col,
                kwargs,
            )
            java_axes = self._bridge.construct_java_object("java.util.HashMap")
            for key, value in axes.items():
                java_axes.put(key, value)
            if by_row_col:
                return self._remote_storage.has_tile_by_row_col(java_axes, resolution_level, row, col)
            return self._remote_storage.has_image(java_axes, resolution_level)

        return self.res_levels[0].has_image(
            self._consolidate_axes(channel, channel_name, z, position, time, row, col, kwargs)
        )

    def read_image(
        self,
        channel=0,
        z=None,
        time=None,
        position=None,
        row=None,
        col=None,
        channel_name=None,
        resolution_level=0,
        memmapped=False,
        **kwargs
    ):
        """
        Read image data as numpy array

        Parameters
        ----------
        channel : int
            index of the channel, if applicable (Default value = 0)
        z : int
            index of z slice, if applicable (Default value = None)
        time : int
            index of the time point, if applicable (Default value = None)
        position : int
            index of the XY position, if applicable (Default value = None)
        channel_name :
            Name of the channel. Overrides channel index if supplied (Default value = None)
        row : int
            index of tile row for XY tiled datasets (Default value = None)
        col : int
            index of tile col for XY tiled datasets (Default value = None)
        resolution_level :
            0 is full resolution, otherwise represents downampling of pixels
            at 2 ** (resolution_level) (Default value = 0)
        memmapped : bool
             (Default value = False)
        **kwargs :
            names and integer positions of any other axes

        Returns
        -------
        image : numpy array
            the image data. For an in-progress acquisition, None is returned if the
            requested image is not present.

        Raises
        ------
        Exception
            If memmapped is requested for an in-progress acquisition.
        """
        axes = self._consolidate_axes(channel, channel_name, z, position, time, row, col, kwargs)

        if self._remote_storage is None:
            res_level = self.res_levels[resolution_level]
            return res_level.read_image(axes, memmapped)

        if memmapped:
            raise Exception("Memory mapping not available for in progress acquisitions")
        # Copy every consolidated axis (not only the extra kwargs) into a Java map
        java_axes = self._bridge.construct_java_object("java.util.HashMap")
        for key, value in axes.items():
            java_axes.put(key, value)
        if not self._remote_storage.has_image(java_axes, resolution_level):
            return None
        tagged_image = self._remote_storage.get_image(java_axes, resolution_level)
        if resolution_level != 0:
            return np.reshape(tagged_image.pix, [self._tile_height, self._tile_width])

        image = np.reshape(
            tagged_image.pix,
            [tagged_image.tags["Height"], tagged_image.tags["Width"]],
        )
        if (self._tile_height is not None) and (self._tile_width is not None):
            # crop down to just the part that shows (i.e. no overlap). Explicit end indices
            # keep a zero overlap from producing an empty ``0:-0`` slice.
            top = (image.shape[0] - self._tile_height) // 2
            left = (image.shape[1] - self._tile_width) // 2
            image = image[top : top + self._tile_height, left : left + self._tile_width]
        return image

    def read_metadata(
        self,
        channel=0,
        z=None,
        time=None,
        position=None,
        channel_name=None,
        row=None,
        col=None,
        resolution_level=0,
        **kwargs
    ):
        """
        Read metadata only. Faster than using read_image to retrieve metadata

        Parameters
        ----------
        channel : int
            index of the channel, if applicable (Default value = 0)
        z : int
            index of z slice, if applicable (Default value = None)
        time : int
            index of the time point, if applicable (Default value = None)
        position : int
            index of the XY position, if applicable (Default value = None)
        channel_name :
            Name of the channel. Overrides channel index if supplied (Default value = None)
        row : int
            index of tile row for XY tiled datasets (Default value = None)
        col : int
            index of tile col for XY tiled datasets (Default value = None)
        resolution_level :
            0 is full resolution, otherwise represents downampling of pixels
            at 2 ** (resolution_level) (Default value = 0)
        **kwargs :
            names and integer positions of any other axes

        Returns
        -------
        metadata : dict
            For an in-progress acquisition, None is returned if the requested image is not present.
        """
        axes = self._consolidate_axes(channel, channel_name, z, position, time, row, col, kwargs)

        if self._remote_storage is None:
            res_level = self.res_levels[resolution_level]
            return res_level.read_metadata(axes)

        # Copy every consolidated axis (not only the extra kwargs) into a Java map
        java_axes = self._bridge.construct_java_object("java.util.HashMap")
        for key, value in axes.items():
            java_axes.put(key, value)
        if not self._remote_storage.has_image(java_axes, resolution_level):
            return None
        # TODO: could speed this up a lot on the Java side by only reading metadata instead of pixels too
        return self._remote_storage.get_image(java_axes, resolution_level).tags

    def close(self):
        """Release every resolution level of an on-disk dataset.

        Does nothing when ``self._remote_storage`` is set (in-progress acquisition),
        because that storage is handled on the Java side.
        """
        if self._remote_storage is None:
            for level in self.res_levels:
                level.close()

    def get_channel_names(self):
        """Return the dataset's channel names as a keys view of ``self._channel_names``.

        Raises
        ------
        Exception
            If called on an in-progress acquisition.
        """
        if self._remote_storage is None:
            return self._channel_names.keys()
        raise Exception("Not implemented for in progress datasets")

    def _consolidate_axes(self, channel, channel_name, z, position, time, row, col, kwargs):
        axes = {}
        if channel is not None:
            axes[self._CHANNEL_AXIS] = channel
        if channel_name is not None:
            axes[self._CHANNEL_AXIS] = self._channel_names[channel_name]
        if z is not None:
            axes[self._Z_AXIS] = z
        if position is not None:
            axes[self._POSITION_AXIS] = position
        if time is not None:
            axes[self._TIME_AXIS] = time
        if row is not None:
            axes[self._ROW_AXIS] = row
        if col is not None:
            axes[self._COLUMN_AXIS] = col
        for other_axis_name in kwargs.keys():
            axes[other_axis_name] = kwargs[other_axis_name]
        return axes