예제 #1
0
class CounterPart(Part):
    """Defines a counter `Attribute` with zero and increment `Method` objects"""

    def __init__(self, name):
        # type: (APartName) -> None
        super(CounterPart, self).__init__(name)
        # The Attribute is built here so that setup() can register it later.
        # TODO: this attribute does not show up in the generated docs for
        # CounterPart; work out why
        counter_tags = [config_tag(), Widget.TEXTINPUT.tag()]
        counter_meta = NumberMeta("float64",
                                  "The current value of the counter",
                                  tags=counter_tags)
        #: Attribute holding the current counter value
        self.counter = counter_meta.create_attribute_model()

    def setup(self, registrar):
        # type: (PartRegistrar) -> None
        # Expose the counter Attribute (written through its own set_value)
        # and the zero/increment Methods on the Block
        registrar.add_attribute_model(
            "counter", self.counter, self.counter.set_value)
        for method in (self.zero, self.increment):
            registrar.add_method_model(method)

    def zero(self):
        """Zero the counter attribute"""
        self.counter.set_value(0)

    def increment(self):
        """Add one to the counter attribute"""
        current = self.counter.value
        self.counter.set_value(current + 1)
예제 #2
0
class CounterPart(Part):
    """Defines a counter `Attribute` with zero and increment `Method` objects"""

    #: Writeable Attribute holding the current counter value
    counter: Optional[AttributeModel] = None
    #: Writeable Attribute holding the amount to increment() by
    delta: Optional[AttributeModel] = None

    def setup(self, registrar: PartRegistrar) -> None:
        super().setup(registrar)
        # Both Attributes are float64 config values shown as text inputs,
        # made writeable by registering their own set_value as put function
        counter_meta = NumberMeta(
            "float64",
            "The current value of the counter",
            tags=[config_tag(), Widget.TEXTINPUT.tag()],
        )
        self.counter = counter_meta.create_attribute_model()
        registrar.add_attribute_model(
            "counter", self.counter, self.counter.set_value
        )

        delta_meta = NumberMeta(
            "float64",
            "The amount to increment() by",
            tags=[config_tag(), Widget.TEXTINPUT.tag()],
        )
        # Starting at 1 means increment() adds one until delta is changed
        self.delta = delta_meta.create_attribute_model(initial_value=1)
        registrar.add_attribute_model("delta", self.delta, self.delta.set_value)

        for method in (self.zero, self.increment):
            registrar.add_method_model(method)

    def zero(self):
        """Zero the counter attribute"""
        self.counter.set_value(0)

    def increment(self):
        """Add delta to the counter attribute"""
        new_value = self.counter.value + self.delta.value
        self.counter.set_value(new_value)
예제 #3
0
class RunnableController(builtin.controllers.ManagerController):
    """RunnableDevice implementer that also exposes GUI for child parts"""

    # The state_set that this controller implements
    state_set = ss()

    def __init__(
        self,
        mri: AMri,
        config_dir: AConfigDir,
        template_designs: ATemplateDesigns = "",
        initial_design: AInitialDesign = "",
        description: ADescription = "",
    ) -> None:
        super().__init__(
            mri=mri,
            config_dir=config_dir,
            template_designs=template_designs,
            initial_design=initial_design,
            description=description,
        )
        # Shared contexts between Configure, Run, Pause, Seek, Resume
        self.part_contexts: Dict[Part, Context] = {}
        # Any custom ConfigureParams subclasses requested by Parts
        self.part_configure_params: PartConfigureParams = {}
        # Params passed to configure()
        self.configure_params: Optional[ConfigureParams] = None
        # Progress reporting dict of completed_steps for each part
        self.progress_updates: Optional[Dict[Part, int]] = None
        # Queue so that run() can wait to see why it was aborted and resume if
        # needed (True means resume, False means give up)
        self.resume_queue: Optional[Queue] = None
        # Number of steps to do in each run() call, calculated in configure().
        # Without breakpoints this holds a single entry
        self.steps_per_run: List[int] = []
        # True if configure() was given a non-empty list of breakpoints
        self.use_breakpoints: bool = False
        # Cumulative step counts at which each run() returns (breakpoints only)
        self.breakpoint_steps: List[int] = []
        # Index into steps_per_run/breakpoint_steps, modified in run() and pause()
        self.breakpoint_index: int = 0
        # Queue so we can wait for aborts to complete
        self.abort_queue: Optional[Queue] = None
        # Create sometimes writeable attribute for the current completed scan
        # step. pause() is registered as its put function, so writing it seeks
        self.completed_steps = NumberMeta(
            "int32",
            "Readback of number of scan steps",
            tags=[Widget.METER.tag()],
        ).create_attribute_model(0)
        self.field_registry.add_attribute_model(
            "completedSteps", self.completed_steps, self.pause
        )
        self.set_writeable_in(self.completed_steps, ss.PAUSED, ss.ARMED)
        # Create read-only attribute for the number of configured scan steps
        self.configured_steps = NumberMeta(
            "int32",
            "Number of steps currently configured",
            tags=[Widget.TEXTUPDATE.tag()],
        ).create_attribute_model(0)
        self.field_registry.add_attribute_model(
            "configuredSteps", self.configured_steps
        )
        # Create read-only attribute for the total number scan steps
        self.total_steps = NumberMeta(
            "int32", "Readback of number of scan steps", tags=[Widget.TEXTUPDATE.tag()]
        ).create_attribute_model(0)
        self.field_registry.add_attribute_model("totalSteps", self.total_steps)
        # Create the method models
        self.field_registry.add_method_model(self.validate)
        self.set_writeable_in(
            self.field_registry.add_method_model(self.configure), ss.READY, ss.FINISHED
        )
        self.set_writeable_in(self.field_registry.add_method_model(self.run), ss.ARMED)
        self.set_writeable_in(
            self.field_registry.add_method_model(self.abort),
            ss.READY,
            ss.CONFIGURING,
            ss.ARMED,
            ss.RUNNING,
            ss.POSTRUN,
            ss.PAUSED,
            ss.SEEKING,
            ss.FINISHED,
        )
        self.set_writeable_in(
            self.field_registry.add_method_model(self.pause),
            ss.ARMED,
            ss.PAUSED,
            ss.RUNNING,
            ss.FINISHED,
        )
        self.set_writeable_in(
            self.field_registry.add_method_model(self.resume), ss.PAUSED
        )
        # Override reset to work from aborted too
        self.set_writeable_in(
            self.field_registry.get_field("reset"),
            ss.FAULT,
            ss.DISABLED,
            ss.ABORTED,
            ss.ARMED,
            ss.FINISHED,
        )
        # Allow Parts to report their status
        self.info_registry.add_reportable(RunProgressInfo, self.update_completed_steps)
        # Allow Parts to request extra items from configure
        self.info_registry.add_reportable(
            ConfigureParamsInfo, self.update_configure_params
        )

    def get_steps_per_run(
        self,
        generator: CompoundGenerator,
        axes_to_move: AAxesToMove,
        breakpoints: List[int],
    ) -> List[int]:
        """Work out how many steps each run() call should perform.

        Dimensions are walked from innermost outwards and their sizes
        multiplied together until a dimension with axes is reached after every
        axis in axes_to_move has been consumed; that product is the size of
        the "inner" run. Without breakpoints a one element list holding it is
        returned. With breakpoints the inner run is split at each breakpoint
        (plus a final entry for any remainder), that pattern is repeated for
        every outer step, and self.breakpoint_steps is set to the cumulative
        step counts at which each run returns.

        Also sets self.use_breakpoints.

        Raises:
            AssertionError: if an inner dimension scans an axis that is not
                (or no longer) in axes_to_move, or if the breakpoints sum to
                more than the inner run size.
        """
        self.use_breakpoints = False
        steps = [1]
        axes_set = set(axes_to_move)
        for dim in reversed(generator.dimensions):
            # If the axes_set is empty and the dimension has axes then we have
            # done as many dimensions as we can, so return
            if dim.axes and not axes_set:
                break
            # Consume the axes that this generator scans
            for axis in dim.axes:
                assert axis in axes_set, f"Axis {axis} is not in {axes_to_move}"
                axes_set.remove(axis)
            # Now multiply by the dimensions to get the number of steps
            steps[0] *= dim.size

        # If we have breakpoints we make a list of steps
        if len(breakpoints) > 0:
            total_breakpoint_steps = sum(breakpoints)
            assert (
                total_breakpoint_steps <= steps[0]
            ), "Sum of breakpoints greater than steps in scan"
            self.use_breakpoints = True

            # Cast to list so we can append
            breakpoints_list = list(breakpoints)

            # Check if we need to add the final breakpoint to the inner scan
            if total_breakpoint_steps < steps[0]:
                last_breakpoint = steps[0] - total_breakpoint_steps
                breakpoints_list += [last_breakpoint]

            # Repeat the set of breakpoints for each outer step
            breakpoints_list *= self._get_outer_steps(generator, axes_to_move)

            steps = breakpoints_list

            # List of steps completed at end of each run
            self.breakpoint_steps = [sum(steps[:i]) for i in range(1, len(steps) + 1)]

        return steps

    def _get_outer_steps(self, generator, axes_to_move):
        """Return how many times the inner run is repeated over the scan.

        This is the product of the sizes of the dimensions that
        get_steps_per_run() does NOT fold into the inner run, found with the
        same stopping rule. Dimensions without axes that sit inside that
        boundary belong to the inner run and so are not counted here.
        """
        outer_steps = 1
        axes_set = set(axes_to_move)
        in_outer = False
        for dim in reversed(generator.dimensions):
            # Same rule as get_steps_per_run(): once every axis to move has
            # been consumed, the next dimension with axes starts the outer part
            if not in_outer and dim.axes and not axes_set:
                in_outer = True
            if in_outer:
                outer_steps *= dim.size
            else:
                axes_set.difference_update(dim.axes)
        return outer_steps

    def do_reset(self):
        """Reset via the superclass, then zero the step counters."""
        super().do_reset()
        self.configured_steps.set_value(0)
        self.completed_steps.set_value(0)
        self.total_steps.set_value(0)
        self.breakpoint_index = 0

    def update_configure_params(
        self, part: Part = None, info: ConfigureParamsInfo = None
    ) -> None:
        """Tell controller part needs different things passed to Configure

        Stores info for part (if given). Once a process is attached, it
        rebuilds the takes/returns/defaults of the configure and validate
        Methods from the configure() signature plus every stored part info.
        """
        with self.changes_squashed:
            # Update the dict
            if part:
                assert info, "No info for part"
                self.part_configure_params[part] = info

            # No process yet, so don't do this yet
            if self.process is None:
                return

            # Make a list of all the infos that the parts have contributed
            part_configure_infos = []
            for child in self.parts.values():
                child_info = self.part_configure_params.get(child, None)
                if child_info:
                    part_configure_infos.append(child_info)

            # Update methods from the updated configure model
            for method_name in ("configure", "validate"):
                # Get the model of our configure method as the starting point.
                # A fresh meta is built per method so they don't share elements
                method_meta = MethodMeta.from_callable(self.configure)
                # Update the configure model from the infos
                update_configure_model(method_meta, part_configure_infos)
                # Put the created metas onto our block meta
                method = self._block[method_name]
                method.meta.takes.set_elements(method_meta.takes.elements)
                method.meta.takes.set_required(method_meta.takes.required)
                method.meta.returns.set_elements(method_meta.returns.elements)
                method.meta.returns.set_required(method_meta.returns.required)
                method.meta.set_defaults(method_meta.defaults)
                method.set_took()
                method.set_returned()

    def update_block_endpoints(self):
        """Refresh block endpoints, then rebuild the configure/validate metas."""
        super().update_block_endpoints()
        self.update_configure_params()

    def _part_params(
        self, part_contexts: Dict[Part, Context] = None, params: ConfigureParams = None
    ) -> PartContextParams:
        """Yield (part, context, kwargs) for every part context.

        kwargs maps each name in params.call_types to its value on params.
        part_contexts and params default to self.part_contexts and
        self.configure_params.
        """
        if part_contexts is None:
            part_contexts = self.part_contexts
        if params is None:
            params = self.configure_params
        for part, context in part_contexts.items():
            args = {}
            assert params, "No params"
            for k in params.call_types:
                args[k] = getattr(params, k)
            yield part, context, args

    # This will be serialized, so maintain camelCase for axesToMove
    # noinspection PyPep8Naming
    @add_call_types
    def validate(
        self,
        generator: AGenerator,
        axesToMove: AAxesToMove = None,
        breakpoints: ABreakpoints = None,
        **kwargs: Any,
    ) -> AConfigureParams:
        """Validate configuration parameters and return validated parameters.

        Doesn't take device state into account so can be run in any state
        """
        iterations = 10
        # We will return this, so make sure we fill in defaults
        for k, default in self._block.configure.meta.defaults.items():
            kwargs.setdefault(k, default)
        # The validated parameters we will eventually return
        params = ConfigureParams(generator, axesToMove, breakpoints, **kwargs)
        # Make some tasks just for validate
        part_contexts = self.create_part_contexts()
        # Get any status from all parts
        status_part_info = self.run_hooks(
            ReportStatusHook(p, c) for p, c in part_contexts.items()
        )
        # Try up to `iterations` times to get a valid set of parameters
        for _ in range(iterations):
            # Validate the params with all the parts
            validate_part_info = self.run_hooks(
                ValidateHook(p, c, status_part_info, **part_kwargs)
                for p, c, part_kwargs in self._part_params(part_contexts, params)
            )
            tweaks: List[ParameterTweakInfo] = ParameterTweakInfo.filter_values(
                validate_part_info
            )
            if not tweaks:
                # Consistent set, just return the params
                return params
            # Generator tweaks are resolved first, on their own
            generator_tweaks: List[ParameterTweakInfo] = [
                tweak for tweak in tweaks if tweak.parameter == "generator"
            ]
            if generator_tweaks:
                # Resolve multiple tweaks to the generator
                generator_tweak = resolve_generator_tweaks(generator_tweaks)
                deserialized = self._block.configure.meta.takes.elements[
                    generator_tweak.parameter
                ].validate(generator_tweak.value)
                setattr(params, generator_tweak.parameter, deserialized)
                self.log.debug(f"{self.mri}: tweaking generator to {deserialized}")
            else:
                # Other tweaks can be applied at the same time
                for tweak in tweaks:
                    deserialized = self._block.configure.meta.takes.elements[
                        tweak.parameter
                    ].validate(tweak.value)
                    setattr(params, tweak.parameter, deserialized)
                    self.log.debug(
                        f"{self.mri}: tweaking {tweak.parameter} to {deserialized}"
                    )
        raise ValueError("Could not get a consistent set of parameters")

    def abortable_transition(self, state):
        """Transition to state unless an abort has just been requested.

        Sleeping 0 on our own context raises AbortedError if it has been
        stopped, so an abort wins over this transition.
        """
        with self._lock:
            # We might have been aborted just now, so this will fail
            # with an AbortedError if we were
            self_ctx = self.part_contexts.get(self, None)
            if self_ctx:
                self_ctx.sleep(0)
            self.transition(state)

    # This will be serialized, so maintain camelCase for axesToMove
    # noinspection PyPep8Naming
    @add_call_types
    def configure(
        self,
        generator: AGenerator,
        axesToMove: AAxesToMove = None,
        breakpoints: ABreakpoints = None,
        **kwargs: Any,
    ) -> AConfigureParams:
        """Validate the params then configure the device ready for run().

        Try to prepare the device as much as possible so that run() is quick to
        start, this may involve potentially long running activities like moving
        motors.

        Normally it will return in Armed state. If the user aborts then it will
        return in Aborted state. If something goes wrong it will return in Fault
        state. If the user disables then it will return in Disabled state.
        """
        params = self.validate(generator, axesToMove, breakpoints, **kwargs)
        state = self.state.value
        try:
            self.transition(ss.CONFIGURING)
            self.do_configure(state, params)
            self.abortable_transition(ss.ARMED)
        except AbortedError:
            # Let try_aborting_function() know we have stopped
            assert self.abort_queue, "No abort queue"
            self.abort_queue.put(None)
            raise
        except Exception as e:
            self.go_to_error_state(e)
            raise
        else:
            return params

    def do_configure(self, state: str, params: ConfigureParams) -> None:
        """Do the work of configure() for validated params.

        Args:
            state: The state the block was in before configure() was called
            params: The validated parameters
        """
        if state == ss.FINISHED:
            # If we were finished then do a reset before configuring
            self.run_hooks(
                builtin.hooks.ResetHook(p, c)
                for p, c in self.create_part_contexts().items()
            )
        # Clear out any old part contexts now rather than letting gc do it
        for context in self.part_contexts.values():
            context.unsubscribe_all()
        # These are the part tasks that abort() and pause() will operate on
        self.part_contexts = self.create_part_contexts()
        # So add one for ourself too so we can be aborted
        assert self.process, "No attached process"
        self.part_contexts[self] = Context(self.process)
        # Store the params for use in seek()
        self.configure_params = params
        # Tell everything to get into the right state to Configure
        self.run_hooks(PreConfigureHook(p, c) for p, c in self.part_contexts.items())
        # This will calculate what we need from the generator, possibly a long
        # call
        params.generator.prepare()
        # Set the steps attributes that we will do across many run() calls
        self.total_steps.set_value(params.generator.size)
        self.completed_steps.set_value(0)
        self.configured_steps.set_value(0)
        # TODO: We can be cleverer about this and support a different number
        # of steps per run for each run by examining the generator structure
        self.steps_per_run = self.get_steps_per_run(
            params.generator, params.axesToMove, params.breakpoints
        )
        # Get any status from all parts
        part_info = self.run_hooks(
            ReportStatusHook(p, c) for p, c in self.part_contexts.items()
        )
        # Run the configure command on all parts, passing them info from
        # ReportStatus. Parts should return any reporting info for PostConfigure
        completed_steps = 0
        self.breakpoint_index = 0
        steps_to_do = self.steps_per_run[self.breakpoint_index]
        part_info = self.run_hooks(
            ConfigureHook(p, c, completed_steps, steps_to_do, part_info, **kw)
            for p, c, kw in self._part_params()
        )
        # Take configuration info and reflect it as attribute updates
        self.run_hooks(
            PostConfigureHook(p, c, part_info) for p, c in self.part_contexts.items()
        )
        # Update the completed and configured steps
        self.configured_steps.set_value(steps_to_do)
        self.completed_steps.meta.display.set_limitHigh(params.generator.size)
        # Reset the progress of all child parts
        self.progress_updates = {}
        self.resume_queue = Queue()

    @add_call_types
    def run(self) -> None:
        """Run a device where configure() has already be called

        Normally it will return in Ready state. If setup for multiple-runs with
        a single configure() then it will return in Armed state. If the user
        aborts then it will return in Aborted state. If something goes wrong it
        will return in Fault state. If the user disables then it will return in
        Disabled state.
        """

        if self.configured_steps.value < self.total_steps.value:
            next_state = ss.ARMED
        else:
            next_state = ss.FINISHED
        try:
            self.transition(ss.RUNNING)
            hook = RunHook
            going = True
            while going:
                try:
                    self.do_run(hook)
                    self.abortable_transition(next_state)
                except AbortedError:
                    # Let try_aborting_function() know we have stopped
                    assert self.abort_queue, "No abort queue"
                    self.abort_queue.put(None)
                    # Wait for a response on the resume_queue: resume() puts
                    # True, abort()/disable/error put False
                    assert self.resume_queue, "No resume queue"
                    should_resume = self.resume_queue.get()
                    if should_resume:
                        # we need to resume
                        self.log.debug("Resuming run")
                    else:
                        # we don't need to resume, just drop out
                        raise
                else:
                    going = False
        except AbortedError:
            raise
        except Exception as e:
            self.go_to_error_state(e)
            raise

    def do_run(self, hook):
        # type: (Type[ControllerHook]) -> None
        """Run one pass of the scan using the given hook class.

        Runs PreRunHook then hook on every part context and moves to PostRun.
        If steps remain it advances breakpoint_index (when using breakpoints),
        runs PostRunArmedHook so parts can set up the next run and updates
        configuredSteps; otherwise it runs PostRunReadyHook.
        """

        # Run all PreRunHooks
        self.run_hooks(PreRunHook(p, c) for p, c in self.part_contexts.items())

        self.run_hooks(hook(p, c) for p, c in self.part_contexts.items())
        self.abortable_transition(ss.POSTRUN)
        completed_steps = self.configured_steps.value
        if completed_steps < self.total_steps.value:
            if self.use_breakpoints:
                self.breakpoint_index += 1
            steps_to_do = self.steps_per_run[self.breakpoint_index]
            part_info = self.run_hooks(
                ReportStatusHook(p, c) for p, c in self.part_contexts.items()
            )
            self.completed_steps.set_value(completed_steps)
            self.run_hooks(
                PostRunArmedHook(
                    p, c, completed_steps, steps_to_do, part_info, **kwargs
                )
                for p, c, kwargs in self._part_params()
            )
            self.configured_steps.set_value(completed_steps + steps_to_do)
        else:
            self.completed_steps.set_value(completed_steps)
            self.run_hooks(
                PostRunReadyHook(p, c) for p, c in self.part_contexts.items()
            )

    def update_completed_steps(
        self, part: Part, completed_steps: RunProgressInfo
    ) -> None:
        """Record a part's progress and update completedSteps.

        completedSteps is set to the minimum step reported across all parts,
        and is only ever increased here.
        """
        with self._lock:
            # Update
            assert self.progress_updates is not None, "No progress updates"
            self.progress_updates[part] = completed_steps.steps
            min_completed_steps = min(self.progress_updates.values())
            if min_completed_steps > self.completed_steps.value:
                self.completed_steps.set_value(min_completed_steps)

    @add_call_types
    def abort(self) -> None:
        """Abort the current operation and block until aborted

        Normally it will return in Aborted state. If something goes wrong it
        will return in Fault state. If the user disables then it will return in
        Disabled state.
        """
        self.try_aborting_function(ss.ABORTING, ss.ABORTED, self.do_abort)
        # Tell run() not to resume if it is waiting on the resume_queue
        if self.resume_queue:
            self.resume_queue.put(False)

    def do_abort(self) -> None:
        """Run AbortHook on all parts.

        A new set of part contexts is used because the ones in
        self.part_contexts have been stopped by try_aborting_function().
        """
        self.run_hooks(AbortHook(p, c) for p, c in self.create_part_contexts().items())

    def try_aborting_function(
        self, start_state: str, end_state: str, func: Callable[..., None], *args: Any
    ) -> None:
        """Interrupt whatever is running, then call func and finish.

        Transitions to start_state and stops every part context. If something
        was running, waits up to DEFAULT_TIMEOUT for it to report via
        abort_queue. It then calls func(*args) and transitions to end_state.
        On failure it goes to the error state and re-raises.
        """
        try:
            # To make the running function fail we need to stop any running
            # contexts (if running a hook) or make transition() fail with
            # AbortedError. Both of these are accomplished here
            with self._lock:
                original_state = self.state.value
                self.abort_queue = Queue()
                self.transition(start_state)
                for context in self.part_contexts.values():
                    context.stop()
            if original_state not in (ss.READY, ss.ARMED, ss.PAUSED, ss.FINISHED):
                # Something was running, let it finish aborting
                try:
                    self.abort_queue.get(timeout=DEFAULT_TIMEOUT)
                except TimeoutError:
                    self.log.warning(f"Timeout waiting while {start_state}")
            with self._lock:
                # Now we've waited for a while we can remove the error state
                # for transition in case a hook triggered it rather than a
                # transition
                self_ctx = self.part_contexts.get(self, None)
                if self_ctx:
                    self_ctx.ignore_stops_before_now()
            func(*args)
            self.abortable_transition(end_state)
        except AbortedError:
            assert self.abort_queue, "No abort queue"
            self.abort_queue.put(None)
            raise
        except Exception as e:  # pylint:disable=broad-except
            self.go_to_error_state(e)
            raise

    # Allow camelCase as this will be serialized
    # noinspection PyPep8Naming
    @add_call_types
    def pause(self, lastGoodStep: ALastGoodStep = -1) -> None:
        """Pause a run() so that resume() can be called later, or seek within
        an Armed or Paused state.

        The original call to run() will not be interrupted by pause(), it will
        wait until the scan completes or is aborted.

        Normally it will return in Paused state. If the scan is finished it
        will return in Finished state. If the scan is armed it will return in
        Armed state. If the user aborts then it will return in Aborted state.
        If something goes wrong it will return in Fault state. If the user
        disables then it will return in Disabled state.
        """

        total_steps = self.total_steps.value

        # We need to decide where to go
        if lastGoodStep < 0:
            # If we are finished we do not need to do anything. States are
            # strings, so compare with == rather than relying on identity
            if self.state.value == ss.FINISHED:
                return
            # Otherwise set to number of completed steps
            lastGoodStep = self.completed_steps.value
        # Otherwise make sure we are bound to the total steps of the scan
        elif lastGoodStep >= total_steps:
            lastGoodStep = total_steps - 1

        if self.state.value in [ss.ARMED, ss.FINISHED]:
            # We don't have a run process, free to go anywhere we want
            next_state = ss.ARMED
        else:
            # Need to pause within the bounds of the current run
            if lastGoodStep == self.configured_steps.value:
                lastGoodStep -= 1
            next_state = ss.PAUSED

        self.try_aborting_function(ss.SEEKING, next_state, self.do_pause, lastGoodStep)

    def do_pause(self, completed_steps: int) -> None:
        """Seek all parts to completed_steps and recalculate configured steps.

        Runs PauseHook on all parts, works out how many steps remain in the
        run that contains completed_steps, runs SeekHook so parts can set up
        from there, then updates completedSteps and configuredSteps.

        Args:
            completed_steps: Last good step performed
        """
        self.run_hooks(PauseHook(p, c) for p, c in self.create_part_contexts().items())

        if self.use_breakpoints:
            self.breakpoint_index = self.get_breakpoint_index(completed_steps)
            in_run_steps = (
                completed_steps % self.breakpoint_steps[self.breakpoint_index]
            )
            steps_to_do = self.breakpoint_steps[self.breakpoint_index] - in_run_steps
        else:
            in_run_steps = completed_steps % self.steps_per_run[self.breakpoint_index]
            steps_to_do = self.steps_per_run[self.breakpoint_index] - in_run_steps

        part_info = self.run_hooks(
            ReportStatusHook(p, c) for p, c in self.part_contexts.items()
        )
        self.completed_steps.set_value(completed_steps)
        self.run_hooks(
            SeekHook(p, c, completed_steps, steps_to_do, part_info, **kwargs)
            for p, c, kwargs in self._part_params()
        )
        self.configured_steps.set_value(completed_steps + steps_to_do)

    def get_breakpoint_index(self, completed_steps: int) -> int:
        """Return the index of the run (in breakpoint_steps) containing
        completed_steps.

        NOTE(review): raises IndexError if completed_steps is greater than
        the last entry of breakpoint_steps.
        """
        # If the last point, then return the last index
        if completed_steps == self.breakpoint_steps[-1]:
            return len(self.breakpoint_steps) - 1
        # Otherwise check which index we fall within
        index = 0
        while completed_steps >= self.breakpoint_steps[index]:
            index += 1
        return index

    @add_call_types
    def resume(self) -> None:
        """Resume a paused scan.

        Normally it will return in Running state. If something goes wrong it
        will return in Fault state.
        """
        self.transition(ss.RUNNING)
        assert self.resume_queue, "No resume queue"
        self.resume_queue.put(True)
        # self.run will now take over

    def do_disable(self) -> None:
        """Stop running contexts without waiting, stop a paused run() from
        resuming, then disable via the superclass."""
        # Abort anything that is currently running, but don't wait
        for context in self.part_contexts.values():
            context.stop()
        if self.resume_queue:
            self.resume_queue.put(False)
        super().do_disable()

    def go_to_error_state(self, exception):
        """Stop a paused run() from resuming, then defer to the superclass."""
        if self.resume_queue:
            self.resume_queue.put(False)
        super().go_to_error_state(exception)
예제 #4
0
class PandAManagerController(builtin.controllers.ManagerController):
    """ManagerController for a PandA.

    Talks to the hardware through a PandABlocksClient (hostname/port), creates
    one PandABlockController plus a matching ChildPart per hardware block,
    keeps the busses part up to date, and runs a spawned poll loop that feeds
    field changes to busses and the child controllers.
    """

    def __init__(
        self,
        mri: AMri,
        config_dir: AConfigDir,
        hostname: AHostname = "localhost",
        port: APort = 8888,
        doc_url_base: ADocUrlBase = DOC_URL_BASE,
        poll_period: APollPeriod = 0.1,
        template_designs: ATemplateDesigns = "",
        initial_design: AInitialDesign = "",
        use_git: AUseGit = True,
        description: ADescription = "",
    ) -> None:
        super().__init__(
            mri=mri,
            config_dir=config_dir,
            template_designs=template_designs,
            initial_design=initial_design,
            use_git=use_git,
            description=description,
        )
        self._poll_period = poll_period
        self._doc_url_base = doc_url_base
        # All the bit_out fields and their last reported values
        # {block_name.field_name: value}
        self._bit_outs: Dict[str, bool] = {}
        # The bit_out field values to apply at the start of the next
        # handle_changes() call {block_name.field_name: value}
        self._bit_out_changes: Dict[str, bool] = {}
        # The fields that busses needs to know about
        # {block_name.field_name[.subfield_name]}
        self._bus_fields: Set[str] = set()
        # The child controllers we have created, keyed by block name
        self._child_controllers: Dict[str, PandABlockController] = {}
        # The PandABlock client that does the comms
        self._client = PandABlocksClient(hostname, port, Queue)
        # Filled in by start_poll_loop()
        self._stop_queue = None
        self._poll_spawned = None
        # Poll period reporting
        self.last_poll_period = NumberMeta(
            "float64",
            "The time between the last 2 polls of the hardware",
            tags=[Widget.TEXTUPDATE.tag()],
            display=Display(units="s", precision=3),
        ).create_attribute_model(poll_period)
        self.field_registry.add_attribute_model("lastPollPeriod",
                                                self.last_poll_period)
        # Bus tables
        self.busses: PandABussesPart = self._make_busses()
        self.add_part(self.busses)

    def do_init(self):
        # start the poll loop and make block parts first to fill in our parts
        # before calling _set_block_children()
        self.start_poll_loop()
        super().do_init()

    def start_poll_loop(self):
        """Start the client, child controllers and poll loop if not running.

        Each step is guarded, so calling this again (do_init then do_reset)
        only does the parts that are not already done.
        """
        if not self._client.started:
            # Fresh queue to listen for stop events from stop_poll_loop()
            self._stop_queue = Queue()
            self._client.start(self.process.spawn, socket)
        if not self._child_controllers:
            self._make_child_controllers()
        if self._poll_spawned is None:
            self._poll_spawned = self.process.spawn(self._poll_loop)

    def do_disable(self):
        super().do_disable()
        self.stop_poll_loop()

    def do_reset(self):
        self.start_poll_loop()
        super().do_reset()

    def _poll_loop(self):
        """Poll the client for changes about every ``self._poll_period`` s.

        Returns when anything is put on ``self._stop_queue``. Any exception
        puts the controller into its error state and is re-raised.
        """
        last_poll_update = time.time()
        next_poll = time.time() + self._poll_period
        try:
            while True:
                # Need to make sure we don't consume all the CPU, allow us to be
                # active for 50% of the poll period, so we must sleep at least
                # 50% of the poll period
                min_sleep = self._poll_period * 0.5
                sleep_for = next_poll - time.time()
                if sleep_for < min_sleep:
                    # Going too fast, slow down a bit
                    last_poll_period = self._poll_period + min_sleep - sleep_for
                    sleep_for = min_sleep
                else:
                    last_poll_period = self._poll_period
                try:
                    # If told to stop, we will get something here and return
                    return self._stop_queue.get(timeout=sleep_for)
                except TimeoutError:
                    # No stop, no problem
                    pass
                # Poll for changes
                self.handle_changes(self._client.get_changes())
                # Only publish the measured period if it changed, and at most
                # once per POLL_PERIOD_REPORT
                if (last_poll_period != self.last_poll_period.value
                        and next_poll - last_poll_update > POLL_PERIOD_REPORT):
                    self.last_poll_period.set_value(last_poll_period)
                    last_poll_update = next_poll
                next_poll += last_poll_period
        except Exception as e:
            self.go_to_error_state(e)
            raise

    def stop_poll_loop(self):
        """Ask the poll loop to return, wait for it, then stop the client."""
        if self._poll_spawned:
            self._stop_queue.put(None)
            self._poll_spawned.wait()
            self._poll_spawned = None
        if self._client.started:
            self._client.stop()

    def _make_child_controllers(self):
        """Create a child controller and ChildPart for every PandA block.

        Also creates the bus tables from the pos_out fields and pcap bit
        fields, records which fields busses needs, and pushes an initial set
        of changes through handle_changes().
        """
        self._child_controllers = {}
        controllers = []
        child_parts = []
        pos_names = []
        blocks_data = self._client.get_blocks_data()
        for block_rootname, block_data in blocks_data.items():
            # A single block keeps its root name, numbered blocks get 1..N
            if block_data.number == 1:
                block_names = [block_rootname]
            else:
                block_names = [
                    "%s%d" % (block_rootname, i + 1)
                    for i in range(block_data.number)
                ]
            for block_name in block_names:
                # Look through the BlockData for things we are interested in
                for field_name, field_data in block_data.fields.items():
                    if field_data.field_type == "pos_out":
                        pos_names.append("%s.%s" % (block_name, field_name))

                # Make the child controller and add it to the process
                controller, child_part = self._make_child_block(
                    block_name, block_data)
                controllers.append(controller)
                child_parts.append(child_part)
                self._child_controllers[block_name] = controller
                # If there is only one, make an alias with "1" appended for
                # *METADATA.LABEL lookup
                if block_data.number == 1:
                    self._child_controllers[block_name + "1"] = controller

        self.process.add_controllers(controllers)
        for part in child_parts:
            self.add_part(part)

        # Create the busses from their initial sets of values
        pcap_bit_fields = self._client.get_pcap_bits_fields()
        self.busses.create_busses(pcap_bit_fields, pos_names)
        # busses needs each pos_out field plus these subfields of it
        self._bus_fields = set(pos_names)
        for pos_name in pos_names:
            for suffix in ("CAPTURE", "UNITS", "SCALE", "OFFSET"):
                self._bus_fields.add("%s.%s" % (pos_name, suffix))
        # Track every bit_out (initially False) for toggling, and add them to
        # the set of things that busses needs
        self._bit_outs = dict.fromkeys(self.busses.bits.value.name, False)
        self._bit_out_changes = {}
        self._bus_fields |= set(self._bit_outs)
        self._bus_fields.update(pcap_bit_fields)
        # Handle the initial set of changes to get an initial value
        self.handle_changes(self._client.get_changes())
        # Then once more to let bit_outs toggle back
        self.handle_changes(())
        assert not self._bit_out_changes, (
            "There are still bit_out changes %s" % self._bit_out_changes)

    def _make_busses(self) -> PandABussesPart:
        return PandABussesPart("busses", self._client)

    def _make_child_block(self, block_name, block_data):
        """Return (controller, ChildPart) for one block.

        PCAP additionally gets ARM and DISARM action parts.
        """
        controller = PandABlockController(self._client, self.mri, block_name,
                                          block_data, self._doc_url_base)
        if block_name == "PCAP":
            controller.add_part(
                PandAActionPart(self._client, "*PCAP", "ARM",
                                "Arm position capture", []))
            controller.add_part(
                PandAActionPart(self._client, "*PCAP", "DISARM",
                                "Disarm position capture", []))
        child_part = builtin.parts.ChildPart(name=block_name,
                                             mri=controller.mri,
                                             stateful=False)
        return controller, child_part

    def _handle_change(self, k, v, bus_changes, block_changes,
                       bit_out_changes):
        """Route one (k, v) change from the client.

        Args:
            k: full field name, "block_name.field_name"
            v: new value as reported by the client
            bus_changes: filled in with {full_field: value} for busses
            block_changes: filled in with {block_name: {field_name: value}}
            bit_out_changes: the deferred bit_out values that handle_changes()
                already applied at the start of this batch
        """
        # Handle bit changes
        try:
            current_v = self._bit_outs[k]
        except KeyError:
            # Not a bit
            pass
        else:
            # Convert to a boolean
            v = bool(int(v))
            try:
                changed_to = bit_out_changes[k]
            except KeyError:
                # We didn't already make a change
                if v == current_v:
                    # Same value reported again (presumably the bit toggled
                    # and came back between polls): report the negation now
                    # and store v so it is set back on the next call
                    self._bit_out_changes[k] = v
                    v = not v
            else:
                # Already made a change, defer this value til next time
                # if it is different
                if changed_to != v:
                    self._bit_out_changes[k] = v
                return
            self._bit_outs[k] = v

        # Notify the bus tables if they need to know
        if k in self._bus_fields:
            bus_changes[k] = v

        # Add to the relevant Block changes dict
        block_name, field_name = k.split(".", 1)
        if block_name == "*METADATA":
            if field_name.startswith("LABEL_"):
                # "LABEL_<block>" is routed to <block>'s "LABEL" field
                field_name, block_name = field_name.split("_", 1)
            else:
                # Don't support any non-label metadata fields at the moment
                return
        block_changes.setdefault(block_name, {})[field_name] = v

    def handle_changes(self, changes: Sequence[Tuple[str, str]]) -> None:
        """Distribute a batch of (field, value) changes.

        bit_out values deferred by the previous call are applied first, then
        each entry of ``changes``. busses and the child controllers are
        notified with one shared TimeStamp.
        """
        ts = TimeStamp()
        # {block_name: {field_name: field_value}}
        block_changes: Dict[str, Any] = {}
        # {full_field: field_value}
        bus_changes = {}

        # Process bit outs that need changing
        bit_out_changes = self._bit_out_changes
        self._bit_out_changes = {}
        for k, v in bit_out_changes.items():
            self._bit_outs[k] = v
            bus_changes[k] = v
            # maxsplit=1, the same split _handle_change() uses
            block_name, field_name = k.split(".", 1)
            block_changes.setdefault(block_name, {})[field_name] = v

        # Work out which change is needed for which block
        for key, value in changes:
            self._handle_change(key, value, bus_changes, block_changes,
                                bit_out_changes)

        # Notify the Blocks that they need to handle these changes
        if bus_changes:
            self.busses.handle_changes(bus_changes, ts)
        for block_name, block_changes_values in block_changes.items():
            self._child_controllers[block_name].handle_changes(
                block_changes_values, ts)
예제 #5
0
class ScanRunnerPart(ChildPart):
    """Used to run sets of scans defined in a YAML file with a scan block"""
    def __init__(self, name: APartName, mri: AMri) -> None:
        super().__init__(name, mri, stateful=False, initial_visibility=True)
        self.runner_config = None
        # Context handed to run(); abort() stops it
        self.context: Optional[AContext] = None
        # Scans parsed by loadFile(), keyed by scan name
        self.scan_sets: Dict[str, Scan] = {}

        # Read-only status attributes. Tags are always passed as a list of
        # tag strings, the same way the config attributes below pass them.
        self.runner_state = StringMeta(
            "Runner state",
            tags=[Widget.TEXTUPDATE.tag()]).create_attribute_model("Idle")
        self.runner_status_message = StringMeta(
            "Runner status message",
            tags=[Widget.TEXTUPDATE.tag()]).create_attribute_model("Idle")
        self.scan_file = StringMeta(
            "Path to input scan file",
            tags=[config_tag(),
                  Widget.TEXTINPUT.tag()]).create_attribute_model()
        self.scans_configured = NumberMeta(
            "int64",
            "Number of configured scans",
            tags=[Widget.TEXTUPDATE.tag()]).create_attribute_model()
        self.current_scan_set = StringMeta(
            "Current scan set",
            tags=[Widget.TEXTUPDATE.tag()]).create_attribute_model()
        self.scans_completed = NumberMeta(
            "int64", "Number of scans completed",
            tags=[Widget.TEXTUPDATE.tag()]).create_attribute_model()
        self.scan_successes = NumberMeta("int64",
                                         "Successful scans",
                                         tags=[Widget.TEXTUPDATE.tag()
                                               ]).create_attribute_model()
        self.scan_failures = NumberMeta("int64",
                                        "Failed scans",
                                        tags=[Widget.TEXTUPDATE.tag()
                                              ]).create_attribute_model()
        self.output_directory = StringMeta(
            "Root output directory (will create a sub-directory inside)",
            tags=[config_tag(), Widget.TEXTINPUT.tag()],
        ).create_attribute_model()

    def setup(self, registrar: PartRegistrar) -> None:
        super().setup(registrar)

        # Every attribute is registered writeable through its own set_value,
        # in this order
        attribute_table = (
            ("runnerState", self.runner_state),
            ("runnerStatusMessage", self.runner_status_message),
            ("scanFile", self.scan_file),
            ("scansConfigured", self.scans_configured),
            ("currentScanSet", self.current_scan_set),
            ("scansCompleted", self.scans_completed),
            ("scanSuccesses", self.scan_successes),
            ("scanFailures", self.scan_failures),
            ("outputDirectory", self.output_directory),
        )
        for attr_name, attr_model in attribute_table:
            registrar.add_attribute_model(attr_name, attr_model,
                                          attr_model.set_value)

        # Methods exposed on the Block; run and abort receive a context
        registrar.add_method_model(self.loadFile)
        registrar.add_method_model(self.run, needs_context=True)
        registrar.add_method_model(self.abort, needs_context=True)

    def get_file_contents(self) -> str:
        """Return the text of the file named by the scanFile attribute.

        Raises:
            IOError: if the file cannot be read. The runner is put into the
                FAULT state with a status message before re-raising.
        """
        scan_path = self.scan_file.value
        try:
            with open(scan_path, "r") as handle:
                contents = handle.read()
        except IOError:
            self.set_runner_state(RunnerStates.FAULT)
            self.runner_status_message.set_value("Could not read scan file")
            raise
        return contents

    def parse_yaml(self, string: str) -> Any:
        """Parse ``string`` with ruamel's safe, pure-Python YAML loader.

        Raises:
            YAMLError: on a parse failure. The runner is put into the FAULT
                state with a status message before re-raising.
        """
        try:
            loader = YAML(typ="safe", pure=True)
            return loader.load(string)
        except YAMLError:
            self.set_runner_state(RunnerStates.FAULT)
            self.runner_status_message.set_value("Could not parse scan file")
            raise

    @staticmethod
    def get_kwargs_from_dict(input_dict, kwargs_list):
        """Return the subset of ``input_dict`` for the requested keys.

        ``kwargs_list`` is either a list of key names or a single key (any
        non-list value is treated as one key). Missing keys are skipped.
        """
        if isinstance(kwargs_list, list):
            wanted = kwargs_list
        else:
            wanted = [kwargs_list]
        return {key: input_dict[key] for key in wanted if key in input_dict}

    @staticmethod
    def parse_compound_generator(entry: dict) -> CompoundGenerator:
        """Build a CompoundGenerator from its YAML dictionary.

        Each item of ``entry["generators"]`` must be a ``{"line": {...}}``
        mapping and is turned into a LineGenerator. ``entry`` itself is not
        modified.

        Raises:
            ValueError: if the resulting duration is not positive.
        """
        generators = [
            LineGenerator.from_dict(generator["line"])
            for generator in entry["generators"]
        ]
        # Work on a shallow copy so the caller's parsed YAML is left untouched
        # (re-parsing the same dict would otherwise see generator objects)
        compound_dict = dict(entry)
        compound_dict["generators"] = generators
        compound_generator = CompoundGenerator.from_dict(compound_dict)
        if compound_generator.duration <= 0.0:
            raise ValueError(
                "Negative generator duration - is it missing from the YAML?")
        return compound_generator

    def parse_scan(self, entry: dict) -> None:
        """Parse one ``scan`` YAML entry and store it in ``scan_sets``.

        ``name`` and ``generator`` are required; ``repeats`` is optional and
        passed through to Scan if present.
        """
        scan_name = entry["name"]
        scan_generator = self.parse_compound_generator(entry["generator"])
        optional_args = self.get_kwargs_from_dict(entry, "repeats")
        self.scan_sets[scan_name] = Scan(scan_name, scan_generator,
                                         **optional_args)

    @staticmethod
    def get_current_datetime(time_separator: str = ":") -> str:
        """Return local time as ``YYYY-mm-dd-HH<sep>MM<sep>SS``."""
        time_format = "%Y-%m-%d-%H{0}%M{0}%S".format(time_separator)
        return datetime.now().strftime(time_format)

    # noinspection PyPep8Naming
    def loadFile(self) -> None:
        # Load and parse the YAML file named by the scanFile attribute.
        # (Deliberately no docstring: it would become the Method description.)
        # Only the first key of each top-level item is examined; it must be
        # "scan" in any case, otherwise the runner goes to FAULT and
        # ValueError is raised.

        # Update state
        self.set_runner_state(RunnerStates.LOADING)
        self.runner_status_message.set_value("Loading scan file")

        # Read contents of file into string
        file_contents = self.get_file_contents()

        # Parse the string
        parsed_yaml = self.parse_yaml(file_contents)

        # Empty the current dictionaries
        self.scan_sets = {}

        # Parse the configuration. An empty file loads as None, which is
        # treated like an empty list of entries.
        for item in parsed_yaml or []:
            key = list(item)[0]
            key_name = key.upper()
            if key_name == EntryType.SCAN.name:
                # Look up by the key as written, since matching is
                # case-insensitive
                self.parse_scan(item[key])
            else:
                self.set_runner_state(RunnerStates.FAULT)
                self.runner_status_message.set_value(
                    "Unidentified key in YAML")
                raise ValueError(f"Unidentified object in YAML: {key_name}")

        # Count the number of scans configured
        self.update_scans_configured()

        self.set_runner_state(RunnerStates.CONFIGURED)
        self.runner_status_message.set_value("Load complete")

    def update_scans_configured(self) -> None:
        """Set scansConfigured to the total repeats over all scan sets."""
        total = sum(self.scan_sets[scan_name].repeats
                    for scan_name in self.scan_sets)
        self.scans_configured.set_value(total)

    def create_directory(self, directory: str) -> None:
        """Create ``directory`` with os.mkdir (its parent must already exist).

        Raises:
            IOError: if it cannot be created. The runner is put into the FAULT
                state first, and the underlying OSError is chained as the
                cause so the real reason is not lost.
        """
        try:
            os.mkdir(directory)
        except OSError as error:
            self.set_runner_state(RunnerStates.FAULT)
            self.runner_status_message.set_value("Could not create directory")
            raise IOError(
                f"ERROR: unable to create directory: {directory}") from error

    def create_and_get_sub_directory(self, root_directory: str) -> str:
        """Create and return ``<root>/<mri>-<timestamp>`` for this run."""
        timestamp = self.get_current_datetime(time_separator="-")
        run_directory = f"{root_directory}/{self.mri}-{timestamp}"
        self.create_directory(run_directory)
        return run_directory

    def get_root_directory(self):
        """Return the outputDirectory attribute without trailing slashes.

        Raises:
            ValueError: if outputDirectory is empty.
        """
        root_directory = self.output_directory.value
        if not root_directory:
            self.runner_status_message.set_value("No output directory set")
            raise ValueError("Output directory has not been set")
        # "/" becomes "", so f"{root}/..." still starts at the filesystem root
        return root_directory.rstrip("/")

    @add_call_types
    def abort(self, context: AContext) -> None:
        # Nothing to do until run() has stored its context
        run_context = self.context
        if not run_context:
            return
        # Stop run()'s context, then abort the scan block through the context
        # this method was called with
        run_context.stop()
        context.block_view(self.mri).abort()
        self.set_runner_state(RunnerStates.ABORTED)
        self.runner_status_message.set_value("Aborted scans")

    @add_call_types
    def run(self, context: AContext) -> None:
        # Refuse to start until loadFile() has configured something
        if not self.scan_sets:
            self.runner_status_message.set_value("No scan file loaded")
            raise ValueError(
                "No scan sets configured. Have you loaded a YAML file?")

        # Output layout: <root>/<mri>-<timestamp>/report.txt
        run_directory = self.create_and_get_sub_directory(
            self.get_root_directory())
        report_path = f"{run_directory}/report.txt"

        # Reset the counters before starting
        for counter in (self.scans_completed, self.scan_successes,
                        self.scan_failures):
            counter.set_value(0)
        self.set_runner_state(RunnerStates.RUNNING)

        # Keep the context so abort() can stop it, and get our scan block
        self.context = context
        scan_block = self.context.block_view(self.mri)

        for set_name in self.scan_sets:
            self.run_scan_set(self.scan_sets[set_name], scan_block,
                              run_directory, report_path)

        self.set_runner_state(RunnerStates.FINISHED)
        self.current_scan_set.set_value("")
        self.runner_status_message.set_value("Scans complete")

    def create_and_get_set_directory(self, sub_directory: str,
                                     set_name: str) -> str:
        """Create and return ``<sub_directory>/scanset-<set_name>``."""
        path = f"{sub_directory}/scanset-{set_name}"
        self.create_directory(path)
        return path

    def run_scan_set(
        self,
        scan_set: Scan,
        scan_block: Any,
        sub_directory: str,
        report_filepath: str,
    ) -> None:
        """Run every repeat (numbered from 1) of one scan set."""
        self.current_scan_set.set_value(scan_set.name)

        # All scans of this set are saved below this directory
        set_directory = self.create_and_get_set_directory(
            sub_directory, scan_set.name)

        for number in range(1, scan_set.repeats + 1):
            self.run_scan(
                set_name=scan_set.name,
                scan_block=scan_block,
                set_directory=set_directory,
                scan_number=number,
                report_filepath=report_filepath,
                generator=scan_set.generator,
            )

    def create_and_get_scan_directory(self, set_directory: str,
                                      scan_number: int) -> str:
        """Create and return ``<set_directory>/scan-<scan_number>``."""
        path = f"{set_directory}/scan-{scan_number}"
        self.create_directory(path)
        return path

    @staticmethod
    def scan_is_aborting(scan_block):
        """Return True if the scan block's state is ABORTING.

        Compares with ``==`` rather than ``is``, so the check does not depend
        on the state value being the very same object as
        RunnableStates.ABORTING.
        """
        return scan_block.state.value == RunnableStates.ABORTING

    def run_scan(
        self,
        set_name: str,
        scan_block: Any,
        set_directory: str,
        scan_number: int,
        report_filepath: str,
        generator: CompoundGenerator,
    ) -> None:
        """Configure and run a single scan, then record its outcome.

        Creates the scan's directory, waits while the block is aborting,
        resets it if it is not READY, then configures and runs it. Exceptions
        from configure/run are mapped to a ScanOutcome; the outcome is
        appended to the report and the success/failure counters updated.
        """
        self.runner_status_message.set_value(
            f"Running {set_name}: {scan_number}")
        assert self.context, "No context found"

        # Make individual scan directory
        scan_directory = self.create_and_get_scan_directory(
            set_directory, scan_number)

        # Check if scan can be reset or run
        while self.scan_is_aborting(scan_block):
            self.context.sleep(0.1)

        # Compare by value (!=), not identity, before deciding to reset
        if scan_block.state.value != RunnableStates.READY:
            scan_block.reset()

        # Configure first
        outcome = None
        try:
            scan_block.configure(generator, fileDir=scan_directory)
        except AssertionError:
            outcome = ScanOutcome.MISCONFIGURED
        except Exception as e:
            outcome = ScanOutcome.MISCONFIGURED
            self.log.error(
                f"Unhandled exception for scan {scan_number} in {set_name}: "
                f"({type(e)}) {e}")

        # Run if configure was successful
        start_time = self.get_current_datetime()
        if outcome is None:
            try:
                scan_block.run()
            except TimeoutError:
                outcome = ScanOutcome.TIMEOUT
            except NotWriteableError:
                outcome = ScanOutcome.NOTWRITEABLE
            except AbortedError:
                outcome = ScanOutcome.ABORTED
            except AssertionError:
                outcome = ScanOutcome.FAIL
            except Exception as e:
                outcome = ScanOutcome.OTHER
                self.log.error(
                    f"Unhandled exception for scan {scan_number} in {set_name}: "
                    f"({type(e)}) {e}")
            else:
                outcome = ScanOutcome.SUCCESS

        # Record the outcome
        end_time = self.get_current_datetime()
        report_string = self.get_report_string(set_name, scan_number, outcome,
                                               start_time, end_time)
        self.add_report_line(report_filepath, report_string)

        if outcome is ScanOutcome.SUCCESS:
            self.increment_scan_successes()
        else:
            self.increment_scan_failures()

    def increment_scan_successes(self):
        """Count one more successful scan (and one more completed scan)."""
        successes = self.scan_successes.value + 1
        self.scan_successes.set_value(successes)
        self.increment_scans_completed()

    def increment_scan_failures(self):
        """Count one more failed scan (and one more completed scan)."""
        failures = self.scan_failures.value + 1
        self.scan_failures.set_value(failures)
        self.increment_scans_completed()

    def increment_scans_completed(self):
        """Add one to the scansCompleted attribute."""
        completed = self.scans_completed.value + 1
        self.scans_completed.set_value(completed)

    def get_report_string(
        self,
        set_name: str,
        scan_number: int,
        scan_outcome: ScanOutcome,
        start_time: str,
        end_time: str,
    ) -> str:
        """Format one fixed-width report line (no trailing newline).

        Columns: set name (30), scan number (10), outcome label (14),
        start time (20), end time.
        """
        outcome_label = self.get_enum_label(scan_outcome)
        return (f"{set_name:<30}{scan_number:<10}{outcome_label:<14}"
                f"{start_time:<20}{end_time}")

    def add_report_line(self, report_filepath: str,
                        report_string: str) -> None:
        """Append ``report_string`` plus a newline to the report file.

        Raises:
            IOError: if the file cannot be written. The runner is put into the
                FAULT state first, and the original error is chained as the
                cause.
        """
        try:
            with open(report_filepath, "a+") as report_file:
                report_file.write(f"{report_string}\n")
        except IOError as error:
            self.set_runner_state(RunnerStates.FAULT)
            self.runner_status_message.set_value("Error writing report file")
            raise IOError(
                f"Could not write to report file {report_filepath}"
            ) from error

    @staticmethod
    def get_enum_label(enum_state: Enum) -> str:
        """Return the enum member's name, capitalized (e.g. FAULT -> Fault)."""
        member_name = enum_state.name
        return member_name.capitalize()

    def set_runner_state(self, runner_state: RunnerStates) -> None:
        """Publish ``runner_state`` as its capitalized label on runnerState."""
        label = self.get_enum_label(runner_state)
        self.runner_state.set_value(label)