コード例 #1
0
class FieldRegistry(object):
    """Collect the Method and Attribute models published for a Block.

    ``self.fields`` maps each owner to a list of
    ``(name, model, writeable_func)`` tuples, in the order they were added.
    """

    def __init__(self):
        # type: () -> None
        self.fields = OrderedDict()  # type: FieldDict

    def get_field(self, name):
        # type: (str) -> Field
        """Return the first registered model whose field name is ``name``.

        Owners are searched in the order in which they first registered a
        field, and each owner's fields in the order they were added.

        Raises:
            ValueError: if no owner registered a field called ``name``
        """
        not_found = object()
        candidates = (
            model
            for owner_fields in self.fields.values()
            for (field_name, model, _) in owner_fields
            if field_name == name
        )
        match = next(candidates, not_found)
        if match is not_found:
            raise ValueError("No field named %s found" % (name,))
        return match

    def add_method_model(self,
                         func,  # type: Callable
                         name=None,  # type: Optional[str]
                         description=None,  # type: Optional[str]
                         owner=None,  # type: object
                         ):
        # type: (...) -> MethodModel
        """Build a MethodModel from ``func`` and register it for ``owner``.

        Args:
            func: The callable to expose; it is also stored as the field's
                callable
            name: Field name to publish under, defaults to ``func.__name__``
            description: Passed to MethodModel.from_callable
            owner: The object publishing the field

        Returns:
            The newly created MethodModel
        """
        method_name = func.__name__ if name is None else name
        model = MethodModel.from_callable(func, description)
        self._add_field(owner, method_name, model, func)
        return model

    def add_attribute_model(self,
                            name,  # type: str
                            attr,  # type: AttributeModel
                            writeable_func=None,  # type: Optional[Callable]
                            owner=None,  # type: object
                            ):
        # type: (...) -> AttributeModel
        """Register ``attr`` under ``name`` for ``owner`` and return it.

        ``writeable_func`` (possibly None) is stored alongside the model.
        """
        self._add_field(owner, name, attr, writeable_func)
        return attr

    def _add_field(self, owner, name, model, writeable_func):
        # type: (object, str, Field, Optional[Callable]) -> None
        """Append ``(name, model, writeable_func)`` to ``owner``'s field list.

        Asserts that ``name`` matches CAMEL_RE (i.e. is camelCase).
        """
        assert CAMEL_RE.match(name), \
            "Field %r published by %s is not camelCase" % (name, owner)
        self.fields.setdefault(owner, []).append(
            (name, model, writeable_func))
コード例 #2
0
ファイル: process.py プロジェクト: thomascobb/pymalcolm
class Process(Loggable):
    """Hosts a number of Controllers and provides spawn capabilities"""

    def __init__(self, name: str = "Process") -> None:
        self.set_logger(process_name=name)
        self.name = name
        self._controllers = OrderedDict()  # mri -> Controller
        self._unpublished: Set[str] = set()  # [mri] for unpublishable controllers
        self.state = STOPPED
        self._spawned: List[Spawned] = []
        # Number of spawns since finished ones were last pruned from _spawned
        self._spawn_count = 0

    def start(self, timeout=DEFAULT_TIMEOUT):
        """Start the process going

        Runs ProcessStartHook on every hosted controller, then
        ProcessPublishHook if any of them are publishable.

        Args:
            timeout (float): Maximum amount of time to wait for each spawned
                process. None means forever
        """
        assert self.state == STOPPED, "Process already started"
        self.state = STARTING
        should_publish = self._start_controllers(self._controllers.values(), timeout)
        if should_publish:
            self._publish_controllers(timeout)
        self.state = STARTED

    def _start_controllers(
        self, controller_list: List[Controller], timeout: float = None
    ) -> bool:
        """Run ProcessStartHook on just the given controllers.

        The mri of every controller that reports UnpublishedInfo is added to
        self._unpublished.

        Returns:
            True if at least one of the given controllers did not report
            UnpublishedInfo, i.e. there is something to publish
        """
        infos = self._run_hook(ProcessStartHook, controller_list, timeout=timeout)
        info: UnpublishedInfo
        new_unpublished = set()
        for info in UnpublishedInfo.filter_values(infos):
            new_unpublished.add(info.mri)
        self._unpublished |= new_unpublished
        return len(controller_list) > len(new_unpublished)

    def _publish_controllers(self, timeout):
        """Run ProcessPublishHook with the mris of all publishable controllers.

        Controllers are arranged into a tree: a controller whose part has an
        ``mri`` attribute naming another hosted controller becomes that
        controller's parent. The tree is then walked so that each controller
        is listed before its children, skipping any in self._unpublished.
        """
        tree = OrderedDict()
        is_child = set()

        def add_controller(controller: Controller) -> OrderedDict:
            # Add controller (and, recursively, its children) to tree
            children = OrderedDict()
            tree[controller.mri] = children
            for part in controller.parts.values():
                part_mri = getattr(part, "mri", None)
                if part_mri is None:
                    # This part does not reference another controller
                    continue
                is_child.add(part_mri)
                if part_mri in tree:
                    # Already added (e.g. as a root); graft it in as a child
                    children[part_mri] = tree[part_mri]
                elif part_mri in self._controllers:
                    children[part_mri] = add_controller(self._controllers[part_mri])
            return tree[controller.mri]

        for c in self._controllers.values():
            if c.mri not in is_child:
                add_controller(c)

        published = []
        # Mirror of published for O(1) membership tests
        published_set = set()

        def walk(d, not_at_this_level=()):
            # List every key at this level first, then descend into children
            to_do = []
            for k, v in d.items():
                if k in not_at_this_level:
                    continue
                if k not in published_set and k not in self._unpublished:
                    published.append(k)
                    published_set.add(k)
                if v:
                    to_do.append(v)
            for v in to_do:
                walk(v)

        # Children also appear at the top level of tree, so skip them there
        walk(tree, not_at_this_level=is_child)

        self._run_hook(ProcessPublishHook, timeout=timeout, published=published)

    def _run_hook(self, hook, controller_list=None, timeout=None, **kwargs):
        """Run hook on each controller and wait for all of them to finish.

        Errors are not raised: any controller whose result is an Exception
        is logged as a warning.

        Args:
            hook: Hook class, instantiated once per controller with kwargs
            controller_list: Controllers to run on, defaults to all hosted
            timeout: Passed to wait_hooks

        Returns:
            The infos returned by wait_hooks, keyed by controller mri
        """
        if controller_list is None:
            controller_list = self._controllers.values()
        hooks = [
            hook(controller, **kwargs).set_spawn(self.spawn)
            for controller in controller_list
        ]
        hook_queue, hook_spawned = start_hooks(hooks)
        infos = wait_hooks(
            self.log, hook_queue, hook_spawned, timeout, exception_check=False
        )
        problems = [mri for mri, e in infos.items() if isinstance(e, Exception)]
        if problems:
            self.log.warning("Problem running %s on %s", hook.__name__, problems)
        return infos

    def stop(self, timeout=DEFAULT_TIMEOUT):
        """Stop the process and wait for it to finish

        Runs ProcessStopHook on every controller, waits for everything that
        was spawned, then forgets all controllers. If waiting for a spawned
        object times out, the TimeoutError is re-raised and the state is
        left as STOPPING.

        Args:
            timeout (float): Maximum amount of time to wait for each spawned
                object. None means forever
        """
        assert self.state == STARTED, "Process not started"
        self.state = STOPPING
        # Allow every controller a chance to clean up
        self._run_hook(ProcessStopHook, timeout=timeout)
        for s in self._spawned:
            if not s.ready():
                self.log.debug(
                    "Waiting for %s *%s **%s", s._function, s._args, s._kwargs
                )
            try:
                s.wait(timeout=timeout)
            except TimeoutError:
                self.log.warning(
                    "Timeout waiting for %s *%s **%s", s._function, s._args, s._kwargs
                )
                raise
        self._spawned = []
        self._controllers = OrderedDict()
        self._unpublished = set()
        self.state = STOPPED
        self.log.debug("Done process.stop()")

    def spawn(self, function: Callable[..., Any], *args: Any, **kwargs: Any) -> Spawned:
        """Runs the function in a worker thread, returning a Spawned object

        Args:
            function: Function to run
            args: Positional arguments to run the function with
            kwargs: Keyword arguments to run the function with

        Returns:
            Spawned: Something you can call wait(timeout) on to see when it's
                finished executing
        """
        assert self.state != STOPPED, "Can't spawn when process stopped"
        spawned = Spawned(function, args, kwargs)
        self._spawned.append(spawned)
        self._spawn_count += 1
        # Filter out things that are ready to avoid memory leaks
        if self._spawn_count > SPAWN_CLEAR_COUNT:
            self._clear_spawn_list()
        return spawned

    def _clear_spawn_list(self) -> None:
        """Drop finished Spawned objects and reset the spawn counter."""
        self._spawn_count = 0
        self._spawned = [s for s in self._spawned if not s.ready()]

    def add_controllers(
        self, controllers: List[Controller], timeout: float = None
    ) -> None:
        """Add many controllers to be hosted by this process

        If the process is already starting or started, the new controllers
        are started too, and published if the process is STARTED.

        Args:
            controllers (List[Controller]): List of its controller
            timeout (float): Maximum amount of time to wait for each spawned
                object. None means forever
        """
        for controller in controllers:
            assert (
                controller.mri not in self._controllers
            ), f"Controller already exists for {controller.mri}"
            self._controllers[controller.mri] = controller
            controller.setup(self)
        # NOTE(review): relies on STOPPED being the only falsy state value
        if self.state:
            should_publish = self._start_controllers(controllers, timeout)
            if self.state == STARTED and should_publish:
                self._publish_controllers(timeout)

    def add_controller(self, controller: Controller, timeout: float = None) -> None:
        """Add a controller to be hosted by this process

        Args:
            controller (Controller): Its controller
            timeout (float): Maximum amount of time to wait for each spawned
                object. None means forever
        """
        self.add_controllers([controller], timeout=timeout)

    @property
    def mri_list(self) -> List[str]:
        """The mris of all hosted controllers, in the order they were added."""
        return list(self._controllers)

    def get_controller(self, mri: str) -> Controller:
        """Get controller which can make Block views for this mri

        Raises:
            ValueError: If no controller is registered for mri
        """
        try:
            return self._controllers[mri]
        except KeyError:
            raise ValueError(f"No controller registered for mri '{mri}'")

    def block_view(self, mri: str) -> Any:
        """Get a Block view from a Controller with given mri"""
        controller = self.get_controller(mri)
        block = controller.block_view()
        return block
コード例 #3
0
ファイル: pandatablepart.py プロジェクト: hir12111/pymalcolm
class PandATablePart(PandAFieldPart):
    """This will normally be instantiated by the PandABox assembly, not created
    in yaml"""
    def __init__(
        self,
        client: AClient,
        meta: AMeta,
        block_name: ABlockName,
        field_name: AFieldName,
    ) -> None:
        # Fill in the meta object with the correct headers
        columns = OrderedDict()
        # {camelCase column name: TableFieldData}
        self.field_data = OrderedDict()
        fields = client.get_table_fields(block_name, field_name)
        if not fields:
            # Didn't put any metadata in, make some up. Build a fresh dict
            # rather than mutating whatever the client handed back
            fields = OrderedDict()
            fields["VALUE"] = TableFieldData(31, 0, "The Value", None, True)
        for column_name, field_data in fields.items():
            nbits = field_data.bits_hi - field_data.bits_lo + 1
            if nbits < 1:
                raise ValueError("Bad bits in %s" % (field_data, ))
            if field_data.labels:
                column_meta = ChoiceArrayMeta(choices=field_data.labels)
                widget = Widget.COMBO
            elif nbits == 1:
                column_meta = BooleanArrayMeta()
                widget = Widget.CHECKBOX
            else:
                dtype = get_dtype(nbits, field_data.signed)
                column_meta = NumberArrayMeta(dtype)
                widget = Widget.TEXTINPUT
            column_name = snake_to_camel(column_name)
            column_meta.set_label(camel_to_title(column_name))
            column_meta.set_tags([widget.tag()])
            column_meta.set_description(field_data.description)
            column_meta.set_writeable(True)
            columns[column_name] = column_meta
            self.field_data[column_name] = field_data
        meta.set_elements(columns)
        # Work out how many ints per row
        # TODO: this should be in the block data
        max_bits_hi = max(f.bits_hi for f in self.field_data.values())
        # Bits 0-31 live in word 0, 32-63 in word 1, ... so the row needs the
        # index of the word holding the highest bit, plus one
        self.ints_per_row = max_bits_hi // 32 + 1
        # Superclass will make the attribute for us
        super().__init__(client, meta, block_name, field_name)

    def handle_change(self, value: str, ts: TimeStamp) -> None:
        """Decode the packed int list from the client and set the attribute.

        NOTE(review): value is passed to table_from_list, so it is really a
        sequence of ints (or int-like strings), despite the annotation.
        """
        value = self.table_from_list(value)
        self.attr.set_value_alarm_ts(value, Alarm.ok, ts)

    def set_field(self, value):
        """Pack the table into 32-bit words and send it to the client."""
        int_values = self.list_from_table(value)
        self.client.set_table(self.block_name, self.field_name, int_values)

    def list_from_table(self, table):
        """Pack a table into a flat uint32 array, ints_per_row words per row.

        Each column is masked to its bit width and shifted to bits_lo within
        the word chosen by get_column_index. Choice columns are packed as the
        index of their label.
        """
        # Create a bit array we can contribute to
        nrows = len(table[list(self.field_data)[0]])
        int_matrix = np.zeros((nrows, self.ints_per_row), dtype=np.uint32)
        # For each row, or the right bits of the int values
        for column_name, field_data in self.field_data.items():
            column_value = table[column_name]
            if field_data.labels:
                # Choice, lookup indexes of the label values
                indexes = [field_data.labels.index(v) for v in column_value]
                column_value = np.array(indexes, dtype=np.uint32)
            else:
                # Array, unwrap to get the numpy array
                column_value = column_value.seq
            # Widen to uint64 before masking and shifting: in a narrow dtype
            # (e.g. uint8 shifted by 16) the value would overflow to 0, and a
            # signed dtype can't be masked with a full 32-bit mask under
            # NumPy's Python-int promotion rules. Negative values wrap to
            # two's complement, so the mask keeps their low nbits
            widened = np.asarray(column_value).astype(np.uint64)
            # Left shift the value so it is aligned with the int columns
            _, mask = get_nbits_mask(field_data)
            shifted_column = (widened & mask) << (field_data.bits_lo % 32)
            # Or it with what we currently have
            column_index = get_column_index(field_data)
            int_matrix[..., column_index] |= shifted_column.astype(np.uint32)
        # Flatten it to a list of uints
        int_values = int_matrix.reshape((nrows * self.ints_per_row, ))
        return int_values

    def table_from_list(self, int_values):
        """Unpack a flat list of 32-bit words into a validated table.

        Raises:
            ValueError: from numpy reshape if len(int_values) is not a
                multiple of ints_per_row
        """
        columns = {}
        nrows = len(int_values) // self.ints_per_row
        # Convert to a 1D uint32 array
        u32 = np.array([int(x) for x in int_values], dtype=np.uint32)
        # Reshape to a 2D array
        int_matrix = u32.reshape((nrows, self.ints_per_row))
        # Create the data for each column
        for column_name, field_data in self.field_data.items():
            # Find the right int to operate on
            column_index = get_column_index(field_data)
            int_column = int_matrix[..., column_index]
            # Right shift data, and mask it
            nbits, mask = get_nbits_mask(field_data)
            shifted_column = (int_column >> field_data.bits_lo % 32) & mask
            # If we wanted labels, convert to values here
            if field_data.labels:
                column_value = [field_data.labels[i] for i in shifted_column]
            elif nbits == 1:
                # np.bool was removed in NumPy 1.24; np.bool_ works everywhere
                column_value = shifted_column.astype(np.bool_)
            else:
                # View as the correct type
                dtype = self.meta.elements[column_name].dtype
                column_value = shifted_column.astype(dtype)
            columns[column_name] = column_value
        # Create a table from it
        table = self.meta.validate(self.meta.table_cls(**columns))
        return table
コード例 #4
0
class PandABoxTablePart(PandABoxFieldPart):
    """This will normally be instantiated by the PandABox assembly, not created
    in yaml"""

    def __init__(self, process, control, meta, block_name, field_name,
                 writeable):
        super(PandABoxTablePart, self).__init__(
            process, control, meta, block_name, field_name, writeable)
        # Fill in the meta object with the correct headers
        columns = OrderedDict()
        # {column_name: (bits_hi, bits_lo)}
        self.fields = OrderedDict()
        fields = control.get_table_fields(block_name, field_name)
        # Use a distinct loop name so the field_name argument isn't shadowed
        for column_field, (bits_hi, bits_lo) in fields.items():
            nbits = bits_hi - bits_lo + 1
            if nbits < 1:
                raise ValueError("Bad bits %s:%s" % (bits_hi, bits_lo))
            if nbits == 1:
                column_meta = BooleanArrayMeta(column_field)
                widget_tag = widget("checkbox")
            else:
                if nbits <= 8:
                    dtype = "uint8"
                elif nbits <= 16:
                    dtype = "uint16"
                elif nbits <= 32:
                    dtype = "uint32"
                elif nbits <= 64:
                    dtype = "uint64"
                else:
                    raise ValueError("Bad bits %s:%s" % (bits_hi, bits_lo))
                column_meta = NumberArrayMeta(dtype, column_field)
                widget_tag = widget("textinput")
            label, column_name = make_label_attr_name(column_field)
            column_meta.set_label(label)
            column_meta.set_tags([widget_tag])
            columns[column_name] = column_meta
            self.fields[column_name] = (bits_hi, bits_lo)
        meta.set_elements(TableElementMap(columns))

    def set_field(self, value):
        """Pack the table into 32-bit words and send it via control."""
        int_values = self.list_from_table(value)
        self.control.set_table(self.block_name, self.field_name, int_values)

    def _calc_nconsume(self):
        """Return how many 32-bit words make up one table row."""
        max_bits_hi = max(bits_hi for bits_hi, _ in self.fields.values())
        # Bits 0-31 are in word 0, 32-63 in word 1, ... so a row needs the
        # index of the word holding the highest bit, plus one
        return max_bits_hi // 32 + 1

    def list_from_table(self, table):
        """Pack each table row into nconsume 32-bit words, low word first."""
        int_values = []
        if self.fields:
            nconsume = self._calc_nconsume()
            for row in range(len(table[list(self.fields)[0]])):
                int_value = 0
                for name, (bits_hi, bits_lo) in self.fields.items():
                    max_value = 2 ** (bits_hi - bits_lo + 1)
                    field_value = int(table[name][row])
                    assert field_value < max_value, \
                        "Expected %s[%d] < %s, got %s" % (
                            name, row, max_value, field_value)
                    int_value |= field_value << bits_lo
                # Split the big int into 32-bit numbers
                for _ in range(nconsume):
                    int_values.append(int_value & (2 ** 32 - 1))
                    int_value = int_value >> 32
        return int_values

    def table_from_list(self, int_values):
        """Unpack a flat list of 32-bit words (low word first) into a Table.

        Trailing words that don't make up a whole row are ignored.
        """
        table = Table(self.meta)
        if self.fields:
            nconsume = self._calc_nconsume()

            for i in range(len(int_values) // nconsume):
                int_value = 0
                for c in range(nconsume):
                    int_value += int(int_values[i*nconsume+c]) << (32 * c)
                row = []
                for name, (bits_hi, bits_lo) in self.fields.items():
                    mask = 2 ** (bits_hi + 1) - 1
                    field_value = (int_value & mask) >> bits_lo
                    row.append(field_value)
                table.append(row)
        return table
コード例 #5
0
class PandABlocksTablePart(PandABlocksFieldPart):
    """This will normally be instantiated by the PandABox assembly, not created
    in yaml"""
    def __init__(self, client, meta, block_name, field_name):
        # type: (AClient, AMeta, ABlockName, AFieldName) -> None
        # Fill in the meta object with the correct headers
        columns = OrderedDict()
        # {camelCase column name: TableFieldData}
        self.field_data = OrderedDict()
        fields = client.get_table_fields(block_name, field_name)
        if not fields:
            # Didn't put any metadata in, make some up
            fields["VALUE"] = TableFieldData(31, 0, "The Value", None)
        for column_name, field_data in fields.items():
            nbits = field_data.bits_hi - field_data.bits_lo + 1
            if nbits < 1:
                raise ValueError("Bad bits in %s" % (field_data, ))
            if field_data.labels:
                column_meta = ChoiceArrayMeta(choices=field_data.labels)
                widget = Widget.COMBO
            elif nbits == 1:
                column_meta = BooleanArrayMeta()
                widget = Widget.CHECKBOX
            else:
                if nbits <= 8:
                    dtype = "uint8"
                elif nbits <= 16:
                    dtype = "uint16"
                elif nbits <= 32:
                    dtype = "uint32"
                elif nbits <= 64:
                    dtype = "uint64"
                else:
                    raise ValueError("Bad bits in %s" % (field_data, ))
                column_meta = NumberArrayMeta(dtype)
                widget = Widget.TEXTINPUT
            column_name = snake_to_camel(column_name)
            column_meta.set_label(camel_to_title(column_name))
            column_meta.set_tags([widget.tag()])
            column_meta.set_description(field_data.description)
            column_meta.set_writeable(True)
            columns[column_name] = column_meta
            self.field_data[column_name] = field_data
        meta.set_elements(columns)
        # Superclass will make the attribute for us
        super(PandABlocksTablePart, self).__init__(client, meta, block_name,
                                                   field_name)

    def set_field(self, value):
        """Pack the table into 32-bit words and send it to the client."""
        int_values = self.list_from_table(value)
        self.client.set_table(self.block_name, self.field_name, int_values)

    def _calc_nconsume(self):
        """Return how many 32-bit words make up one table row."""
        max_bits_hi = max(f.bits_hi for f in self.field_data.values())
        # Bits 0-31 are in word 0, 32-63 in word 1, ... so a row needs the
        # index of the word holding the highest bit, plus one
        return max_bits_hi // 32 + 1

    def list_from_table(self, table):
        """Pack each table row into nconsume 32-bit words, low word first.

        Choice columns are packed as the index of their label.
        """
        int_values = []
        nconsume = self._calc_nconsume()
        # Enumerate so the assertion message can report an integer row index
        for row_index, row in enumerate(table.rows()):
            int_value = 0
            for name, value in zip(table.call_types, row):
                field_data = self.field_data[name]
                max_value = 2**(field_data.bits_hi - field_data.bits_lo + 1)
                if field_data.labels:
                    field_value = field_data.labels.index(value)
                else:
                    field_value = int(value)
                assert field_value < max_value, \
                    "Expected %s[%d] < %s, got %s" % (
                        name, row_index, max_value, field_value)
                int_value |= field_value << field_data.bits_lo
            # Split the big int into 32-bit numbers
            for _ in range(nconsume):
                int_values.append(int_value & (2**32 - 1))
                int_value = int_value >> 32
        return int_values

    def table_from_list(self, int_values):
        """Unpack a flat list of 32-bit words (low word first) into a table.

        Trailing words that don't make up a whole row are ignored. Choice
        columns are converted back from index to label.
        """
        rows = []
        nconsume = self._calc_nconsume()
        for i in range(len(int_values) // nconsume):
            int_value = 0
            for c in range(nconsume):
                int_value += int(int_values[i * nconsume + c]) << (32 * c)
            row = []
            for name, field_data in self.field_data.items():
                mask = 2**(field_data.bits_hi + 1) - 1
                field_value = (int_value & mask) >> field_data.bits_lo
                if field_data.labels:
                    # This is a choice meta, so write the string value
                    row.append(field_data.labels[field_value])
                else:
                    row.append(field_value)
            rows.append(row)
        table = self.meta.table_cls.from_rows(rows)
        return table
コード例 #6
0
ファイル: process.py プロジェクト: dls-controls/pymalcolm
class Process(Loggable):
    """Hosts a number of Blocks, distributing requests between them

    Requests and internal bookkeeping messages are put on ``self.q`` and
    handled one at a time by ``recv_loop``, which ``start()`` spawns via the
    sync_factory. ``self._handle_functions`` maps each message type to the
    ``_handle_*``/``_forward_*`` method that handles it.
    """

    def __init__(self, name, sync_factory):
        self.set_logger_name(name)
        self.name = name
        # Used to create queues/locks and to spawn work
        self.sync_factory = sync_factory
        # Queue serviced by recv_loop
        self.q = self.create_queue()
        self._blocks = OrderedDict()  # block_name -> Block
        self._controllers = OrderedDict()  # block_name -> Controller
        # Serialized block state fed by BlockChanges; also tracks subscribers
        self._block_state_cache = Cache()
        # Spawned recv_loop; being non-None doubles as "has been started"
        self._recv_spawned = None
        # [(spawned, function)] for spawned work not yet known to be finished
        self._other_spawned = []
        # lookup of all Subscribe requests, ordered to guarantee subscription
        # notification ordering
        # {Request.generate_key(): Subscribe}
        self._subscriptions = OrderedDict()
        self.comms = []
        self._client_comms = OrderedDict()  # client comms -> list of blocks
        # message type -> handler, dispatched from recv_loop
        self._handle_functions = {
            Post: self._forward_block_request,
            Put: self._forward_block_request,
            Get: self._handle_get,
            Subscribe: self._handle_subscribe,
            Unsubscribe: self._handle_unsubscribe,
            BlockChanges: self._handle_block_changes,
            BlockRespond: self._handle_block_respond,
            BlockAdd: self._handle_block_add,
            BlockList: self._handle_block_list,
            AddSpawned: self._add_spawned,
        }
        self.create_process_block()

    def recv_loop(self):
        """Service self.q, distributing the requests to the right block

        Runs until the PROCESS_STOP sentinel is received. An exception from a
        handler is logged and, if possible, reported back via
        respond_with_error; it does not stop the loop.
        """
        while True:
            request = self.q.get()
            self.log_debug("Received request %s", request)
            if request is PROCESS_STOP:
                # Got the sentinel, stop immediately
                break
            try:
                # An unknown message type raises KeyError here and is
                # handled like any other handler failure
                self._handle_functions[type(request)](request)
            except Exception as e:  # pylint:disable=broad-except
                self.log_exception("Exception while handling %s", request)
                try:
                    request.respond_with_error(str(e))
                except Exception:
                    # Best effort only: presumably not every queued message
                    # can respond; the original error was logged above
                    pass

    def add_comms(self, comms):
        """Register a comms object to be started/stopped with the Process.

        Must be called before start().
        """
        assert not self._recv_spawned, \
            "Can't add comms when process has been started"
        self.comms.append(comms)

    def start(self):
        """Start the process going

        Spawns recv_loop and then starts every registered comms.
        """
        self._recv_spawned = self.sync_factory.spawn(self.recv_loop)
        for comms in self.comms:
            comms.start()

    def stop(self, timeout=None):
        """Stop the process and wait for it to finish

        Queues the stop sentinel, stops all comms, waits for recv_loop and
        then for any other spawned work. The sync_factory is deleted at the
        end, so a stopped Process cannot be started again.

        Args:
            timeout (float): Maximum amount of time to wait for each spawned
                process. None means forever
        """
        assert self._recv_spawned, "Process not started"
        self.q.put(PROCESS_STOP)
        for comms in self.comms:
            comms.stop()
        # Wait for recv_loop to complete first
        self._recv_spawned.wait(timeout=timeout)
        # Now wait for anything it spawned to complete
        for s, _ in self._other_spawned:
            s.wait(timeout=timeout)
        # Garbage collect the syncfactory
        del self.sync_factory

    def _forward_block_request(self, request):
        """Lookup target Block and spawn block.handle_request(request)

        The first element of request.endpoint is taken as the block name.
        Runs in the recv_loop thread, so it records the spawned work directly.

        Args:
            request (Request): The message that should be passed to the Block
        """
        block_name = request.endpoint[0]
        block = self._blocks[block_name]
        spawned = self.sync_factory.spawn(block.handle_request, request)
        self._add_spawned(AddSpawned(spawned, block.handle_request))

    def create_queue(self):
        """
        Create a queue using sync_factory object

        Returns:
            Queue: New queue
        """

        return self.sync_factory.create_queue()

    def create_lock(self):
        """
        Create a lock using sync_factory object

        Returns:
            New lock object
        """
        return self.sync_factory.create_lock()

    def spawn(self, function, *args, **kwargs):
        """Calls SyncFactory.spawn()

        The function is wrapped so that any exception it raises is logged
        before being re-raised. The spawned object is recorded by posting an
        AddSpawned message to self.q, so stop() can wait for it.

        Returns:
            The object returned by sync_factory.spawn
        """
        def catching_function():
            try:
                function(*args, **kwargs)
            except Exception:
                self.log_exception(
                    "Exception calling %s(*%s, **%s)", function, args, kwargs)
                raise
        spawned = self.sync_factory.spawn(catching_function)
        request = AddSpawned(spawned, function)
        self.q.put(request)
        return spawned

    def _add_spawned(self, request):
        """Record request.spawned, and prune entries that are already ready.

        Pruning here stops _other_spawned growing without bound.
        """
        spawned = self._other_spawned
        self._other_spawned = []
        spawned.append((request.spawned, request.function))
        # Filter out the spawned that have completed to stop memory leaks
        for sp, f in spawned:
            if not sp.ready():
                self._other_spawned.append((sp, f))

    def get_client_comms(self, block_name):
        """Return the client comms that reported block_name, or None.

        Iterates over a snapshot (list) of self._client_comms.
        """
        for client_comms, blocks in list(self._client_comms.items()):
            if block_name in blocks:
                return client_comms

    def create_process_block(self):
        """Create and add the Block that describes this Process itself.

        It has two string array attributes: "blocks" (hosted here) and
        "remoteBlocks" (reachable via client comms), both initially empty.
        """
        self.process_block = Block()
        # TODO: add a meta here
        children = OrderedDict()
        children["blocks"] = StringArrayMeta(
            description="Blocks hosted by this Process"
        ).make_attribute([])
        children["remoteBlocks"] = StringArrayMeta(
                description="Blocks reachable via ClientComms"
        ).make_attribute([])
        self.process_block.replace_endpoints(children)
        self.process_block.set_process_path(self, [self.name])
        self.add_block(self.process_block, self)

    def update_block_list(self, client_comms, blocks):
        """Queue a BlockList message recording the blocks client_comms can reach."""
        self.q.put(BlockList(client_comms=client_comms, blocks=blocks))

    def _handle_block_list(self, request):
        """Store a client comms' block list and rebuild "remoteBlocks".

        remoteBlocks lists the blocks of every client comms in first-seen
        order, skipping names already contributed by an earlier comms.
        """
        self._client_comms[request.client_comms] = request.blocks
        remotes = []
        for blocks in self._client_comms.values():
            # NOTE(review): remotes isn't updated inside the comprehension,
            # so duplicates within a single blocks list are both kept
            remotes += [b for b in blocks if b not in remotes]
        self.process_block["remoteBlocks"].set_value(remotes)

    def _handle_block_changes(self, request):
        """Update subscribers with changes and applies stored changes to the
        cached structure"""
        # update cached dict
        subscription_changes = self._block_state_cache.apply_changes(
            *request.changes)

        # Send out the changes
        for subscription, changes in subscription_changes.items():
            if subscription.delta:
                # respond with the filtered changes
                subscription.respond_with_delta(changes)
            else:
                # respond with the structure of everything
                # below the endpoint
                d = self._block_state_cache.walk_path(subscription.endpoint)
                subscription.respond_with_update(d)

    def report_changes(self, *changes):
        """Queue a BlockChanges message carrying the given changes."""
        self.q.put(BlockChanges(changes=list(changes)))

    def block_respond(self, response, response_queue):
        """Queue a BlockRespond so response reaches response_queue via recv_loop."""
        self.q.put(BlockRespond(response, response_queue))

    def _handle_block_respond(self, request):
        """Push the response to the required queue"""
        request.response_queue.put(request.response)

    def add_block(self, block, controller):
        """Add a block to be hosted by this process

        Before start() the block is added immediately in the calling thread;
        afterwards a BlockAdd message is queued for recv_loop to handle.

        Args:
            block (Block): The block to be added
            controller (Controller): Its controller
        """
        path = block.process_path
        assert len(path) == 1, \
            "Expected block %r to have %r as parent, got path %r" % \
            (block, self, path)
        name = path[0]
        assert name not in self._blocks, \
            "There is already a block called %r" % name
        request = BlockAdd(block=block, controller=controller, name=name)
        if self._recv_spawned:
            # Started, so call in Process thread
            self.q.put(request)
        else:
            # Not started yet so we are safe to add in this thread
            self._handle_block_add(request)

    def _handle_block_add(self, request):
        """Add a block to be hosted by this process

        Registers the block and controller, publishes the block's serialized
        form as a change at path [name], and refreshes the "blocks" attribute.
        """
        # Re-checked here as a queued BlockAdd may race with another add
        assert request.name not in self._blocks, \
            "There is already a block called %r" % request.name
        self._blocks[request.name] = request.block
        self._controllers[request.name] = request.controller
        serialized = request.block.to_dict()
        change_request = BlockChanges([[[request.name], serialized]])
        self._handle_block_changes(change_request)
        # Regenerate list of blocks
        self.process_block["blocks"].set_value(list(self._blocks))

    def get_block(self, block_name):
        """Return a hosted Block, or a client Block for a remote one.

        Raises:
            KeyError: if block_name is neither hosted nor listed in
                remoteBlocks
        """
        try:
            return self._blocks[block_name]
        except KeyError:
            if block_name in self.process_block.remoteBlocks:
                return self.make_client_block(block_name)
            else:
                raise

    def make_client_block(self, block_name):
        """Create a ClientController for block_name and return its block."""
        params = ClientController.MethodMeta.prepare_input_map(
            mri=block_name)
        controller = ClientController(self, {}, params)
        return controller.block

    def get_controller(self, block_name):
        """Return the controller registered for block_name (KeyError if none)."""
        return self._controllers[block_name]

    def _handle_subscribe(self, request):
        """Add a new subscriber and respond with the current
        sub-structure state"""
        key = request.generate_key()
        assert key not in self._subscriptions, \
            "Subscription on %s already exists" % (key,)
        self._subscriptions[key] = request
        self._block_state_cache.add_subscriber(request, request.endpoint)
        d = self._block_state_cache.walk_path(request.endpoint)
        self.log_debug("Initial subscription value %s", d)
        if request.delta:
            # Initial delta replaces everything at the (relative) root
            request.respond_with_delta([[[], d]])
        else:
            request.respond_with_update(d)

    def _handle_unsubscribe(self, request):
        """Remove a subscriber and respond with success or error"""
        key = request.generate_key()
        try:
            subscription = self._subscriptions.pop(key)
        except KeyError:
            request.respond_with_error(
                "No subscription found for %s" % (key,))
        else:
            self._block_state_cache.remove_subscriber(
                subscription, subscription.endpoint)
            request.respond_with_return()

    def _handle_get(self, request):
        """Respond with the cached structure at request.endpoint."""
        d = self._block_state_cache.walk_path(request.endpoint)
        request.respond_with_return(d)
コード例 #7
0
ファイル: controller.py プロジェクト: nsob1c12/pymalcolm
class Controller(Loggable):
    """Builds a BlockModel from its own fields and those of its Parts, and
    services Get/Put/Post/Subscribe/Unsubscribe requests against it.

    Parts may hook functions into the Hook attributes of this class; those
    functions are run with run_hook() (or start_hook() + wait_hook()).
    """
    use_cothread = True

    # Attributes
    health = None

    def __init__(self, process, mri, parts, description=""):
        """
        Args:
            process: The Process this Controller belongs to, used for
                spawning functions and creating Contexts
            mri (str): Name of the Block; also used as its initial label and
                as the root of the notifier path
            parts: Parts to add with add_part(), in order
            description (str): Description for the Block Meta
        """
        super(Controller, self).__init__(mri=mri)
        self.process = process
        self.mri = mri
        self._request_queue = Queue()
        # {Part: Alarm} for current faults
        self._faults = {}
        # {Hook: name}
        self._hook_names = {}
        # {Hook: {Part: func_name}}
        self._hooked_func_names = {}
        self._find_hooks()
        # {part_name: [(field_name, Model, setter)]}
        self.part_fields = OrderedDict()
        # {name: Part}
        self.parts = OrderedDict()
        self._lock = RLock(self.use_cothread)
        self._block = BlockModel()
        self._block.meta.set_description(description)
        self.set_label(mri)
        for part in parts:
            self.add_part(part)
        self._notifier = Notifier(mri, self._lock, self._block)
        self._block.set_notifier_path(self._notifier, [mri])
        # {field_name: callable} used by _handle_put / _handle_post
        self._write_functions = {}
        self._add_block_fields()

    def set_label(self, label):
        """Set the label of the Block Meta object"""
        self._block.meta.set_label(label)

    def add_part(self, part):
        """Attach a Part to this Controller.

        Records which of the part's functions are hooked into this
        Controller's Hooks, and stores the fields the part creates in
        part_fields[part.name].

        Args:
            part (Part): The part to add; its name must not already be used
        """
        assert part.name not in self.parts, \
            "Part %r already exists in Controller %r" % (part.name, self.mri)
        part.attach_to_controller(self)
        # Check part hooks into one of our hooks
        for func_name, part_hook, _ in get_hook_decorated(part):
            assert part_hook in self._hook_names, \
                "Part %s func %s not hooked into %s" % (
                    part.name, func_name, self)
            self._hooked_func_names[part_hook][part] = func_name
        part_fields = list(part.create_attribute_models()) + \
                      list(part.create_method_models())
        self.parts[part.name] = part
        self.part_fields[part.name] = part_fields

    def _find_hooks(self):
        """Record every Hook attribute of this Controller under its attribute
        name, and give each an empty {Part: func_name} mapping"""
        for name, member in inspect.getmembers(self, Hook.isinstance):
            # The same Hook bound under two names is an error; report the new
            # name, this controller, then the name it is already known by
            assert member not in self._hook_names, \
                "Hook %s already in %s as %s" % (
                    name, self, self._hook_names[member])
            self._hook_names[member] = name
            self._hooked_func_names[member] = {}

    def _add_block_fields(self):
        """Add our own attribute and method fields, then those of every Part,
        to the Block"""
        for iterable in (self.create_attribute_models(),
                         self.create_method_models(),
                         self.initial_part_fields()):
            for name, child, writeable_func in iterable:
                self.add_block_field(name, child, writeable_func)

    def add_block_field(self, name, child, writeable_func):
        """Add child to the Block under name.

        If writeable_func is given it is stored as the write function for
        name and child is marked writeable (for a MethodModel, its takes
        elements are marked writeable too). A missing label defaults to
        camel_to_title(name).

        Raises:
            ValueError: if child is neither an AttributeModel nor MethodModel
        """
        if writeable_func:
            self._write_functions[name] = writeable_func
        if isinstance(child, AttributeModel):
            if writeable_func:
                child.meta.set_writeable(True)
            if not child.meta.label:
                child.meta.set_label(camel_to_title(name))
        elif isinstance(child, MethodModel):
            if writeable_func:
                child.set_writeable(True)
                for _, v in child.takes.elements.items():
                    v.set_writeable(True)
            if not child.label:
                child.set_label(camel_to_title(name))
        else:
            raise ValueError("Invalid block field %r" % child)
        self._block.set_endpoint_data(name, child)

    def create_method_models(self):
        """Provide MethodModel instances to be attached to BlockModel

        Yields:
            tuple: (string name, MethodModel, callable post_function).
        """
        return get_method_decorated(self)

    def create_attribute_models(self):
        """Provide AttributeModel instances to be attached to BlockModel

        Yields:
            tuple: (string name, AttributeModel, callable put_function).
        """
        # Create read-only attribute to show error texts
        meta = HealthMeta("Displays OK or an error message")
        self.health = meta.create_attribute_model()
        yield "health", self.health, None

    def initial_part_fields(self):
        """Yield every (name, Model, writeable_func) tuple stored by
        add_part(), in the order the parts were added"""
        for part_fields in self.part_fields.values():
            for data in part_fields:
                yield data

    def spawn(self, func, *args, **kwargs):
        """Spawn a function in the right thread

        Delegates to process.spawn, passing use_cothread.
        """
        spawned = self.process.spawn(func, args, kwargs, self.use_cothread)
        return spawned

    @property
    @contextmanager
    def lock_released(self):
        """Context manager that releases self._lock for the duration of the
        with-block and re-acquires it afterwards (even on error)"""
        self._lock.release()
        try:
            yield
        finally:
            self._lock.acquire()

    @property
    def changes_squashed(self):
        """The Notifier's changes_squashed context manager"""
        return self._notifier.changes_squashed

    def update_health(self, part, alarm=None):
        """Set the health attribute. Called from part

        Args:
            part (Part): The part reporting its health
            alarm: Alarm (or something deserialize_object can turn into one).
                None or a zero severity clears this part's fault.

        health then shows the message and Alarm of the highest-severity
        current fault, or "OK" with no alarm if there are none.
        """
        if alarm is not None:
            alarm = deserialize_object(alarm, Alarm)
        with self.changes_squashed:
            if alarm is None or not alarm.severity:
                self._faults.pop(part, None)
            else:
                self._faults[part] = alarm
            if self._faults:
                # Sort them by severity
                faults = sorted(self._faults.values(), key=lambda a: a.severity)
                alarm = faults[-1]
                text = faults[-1].message
            else:
                alarm = None
                text = "OK"
            self.health.set_value(text, alarm=alarm)

    def block_view(self):
        """Get a view of the block we control

        Returns:
            Block: The block we control
        """
        context = Context(self.process)
        return self.make_view(context)

    def make_view(self, context, data=None, child_name=None):
        """Make a child View of data[child_name] (or of the whole Block if
        data is None). If called from the wrong thread, the work is spawned
        in the right one and its result waited for."""
        try:
            return self._make_view(context, data, child_name)
        except WrongThreadError:
            # called from wrong thread, spawn it again
            result = self.spawn(self._make_view, context, data, child_name)
            return result.get()

    def _make_view(self, context, data, child_name):
        """Called in cothread's thread"""
        with self._lock:
            if data is None:
                child = self._block
            else:
                child = data[child_name]
            child_view = self._make_appropriate_view(context, child)
        return child_view

    def _make_appropriate_view(self, context, data):
        """Wrap data in the View type matching its Model type, recursing
        into dicts and lists; any other value is returned unchanged"""
        if isinstance(data, BlockModel):
            # Make an Block View
            return make_block_view(self, context, data)
        elif isinstance(data, AttributeModel):
            # Make an Attribute View
            return Attribute(self, context, data)
        elif isinstance(data, MethodModel):
            # Make a Method View
            return Method(self, context, data)
        elif isinstance(data, Model):
            # Make a generic View of it
            return make_view(self, context, data)
        elif isinstance(data, dict):
            # Need to recurse down
            d = OrderedDict()
            for k, v in data.items():
                d[k] = self._make_appropriate_view(context, v)
            return d
        elif isinstance(data, list):
            # Need to recurse down
            return [self._make_appropriate_view(context, x) for x in data]
        else:
            return data

    def handle_request(self, request):
        """Spawn a new thread that handles Request"""
        # Put data on the queue, so if spawns are handled out of order we
        # still get the most up to date data
        self._request_queue.put(request)
        return self.spawn(self._handle_request)

    def _handle_request(self):
        """Take one request off the queue, handle it with the lock held, then
        call each (callback, response) pair outside the lock.

        Raises:
            UnexpectedError: if the request type is not recognised (raised
                before any response is produced)
        """
        responses = []
        with self._lock:
            # We spawned just above, so there is definitely something on the
            # queue
            request = self._request_queue.get(timeout=0)
            # self.log.debug(request)
            if isinstance(request, Get):
                handler = self._handle_get
            elif isinstance(request, Put):
                handler = self._handle_put
            elif isinstance(request, Post):
                handler = self._handle_post
            elif isinstance(request, Subscribe):
                handler = self._notifier.handle_subscribe
            elif isinstance(request, Unsubscribe):
                handler = self._notifier.handle_unsubscribe
            else:
                # Format the request into the message rather than passing it
                # as a separate exception argument
                raise UnexpectedError("Unexpected request %s" % (request,))
            try:
                responses += handler(request)
            except Exception as e:
                responses.append(request.error_response(e))
        for cb, response in responses:
            try:
                cb(response)
            except Exception:
                self.log.exception("Exception notifying %s", response)
                raise

    def _handle_get(self, request):
        """Called with the lock taken

        Walks request.path[1:] down from the Block and returns the
        serialized value found there.

        Raises:
            UnexpectedError: if a path element does not exist
        """
        data = self._block
        for endpoint in request.path[1:]:
            try:
                data = data[endpoint]
            except KeyError:
                if hasattr(data, "typeid"):
                    typ = data.typeid
                else:
                    typ = type(data)
                raise UnexpectedError(
                    "Object of type %r has no attribute %r" % (typ, endpoint))
        serialized = serialize_object(data)
        ret = [request.return_response(serialized)]
        return ret

    def _handle_put(self, request):
        """Called with the lock taken

        Calls the write function registered for attribute request.path[1]
        with request.value, releasing the lock while it runs.
        """
        attribute_name = request.path[1]

        attribute = self._block[attribute_name]
        assert attribute.meta.writeable, \
            "Attribute %s is not writeable" % attribute_name
        put_function = self._write_functions[attribute_name]

        with self.lock_released:
            result = put_function(request.value)

        ret = [request.return_response(result)]
        return ret

    def _handle_post(self, request):
        """Called with the lock taken

        Calls the function registered for method request.path[1] with
        arguments prepared from request.parameters, releasing the lock while
        it runs, then validates the result against the method's returns.
        """
        method_name = request.path[1]
        if request.parameters:
            param_dict = request.parameters
        else:
            param_dict = {}

        method = self._block[method_name]
        assert method.writeable, \
            "Method %s is not writeable" % method_name
        args = method.prepare_call_args(**param_dict)
        post_function = self._write_functions[method_name]

        with self.lock_released:
            result = post_function(*args)

        result = self.validate_result(method_name, result)
        ret = [request.return_response(result)]
        return ret

    def validate_result(self, method_name, result):
        """If the method declares return elements, wrap result in a Map of
        method.returns and check it is valid; otherwise return it as is"""
        with self._lock:
            method = self._block[method_name]
            # Prepare output map
            if method.returns.elements:
                result = Map(method.returns, result)
                result.check_valid()
        return result

    def create_part_contexts(self):
        """Return {Part: Context} with a fresh Context for every Part"""
        return {part: Context(self.process) for part in self.parts.values()}

    def run_hook(self, hook, part_contexts, *args, **params):
        """Run hook on all parts (start_hook then wait_hook) and return
        {part_name: return value}"""
        hook_queue, hook_runners = self.start_hook(
            hook, part_contexts, *args, **params)
        return_dict = self.wait_hook(hook_queue, hook_runners)
        return return_dict

    def start_hook(self, hook, part_contexts, *args, **params):
        """Start a hook runner for every Part in part_contexts that has a
        function hooked into hook.

        Returns:
            tuple: (hook_queue, {Part: hook_runner}) to pass to wait_hook()
        """
        assert hook in self._hook_names, \
            "Hook %s doesn't appear in controller hooks %s" % (
                hook, self._hook_names)
        hook_name = self._hook_names[hook]
        self.log.debug("%s: Starting hook", hook_name)

        # This queue will hold (part, result) tuples
        hook_queue = Queue()
        hook_queue.hook_name = hook_name
        hook_runners = {}

        # now start them off
        # Take the lock so that no hook abort can come in between now and
        # the spawn of the context
        with self._lock:
            for part, context in part_contexts.items():
                # context might have been aborted but have nothing servicing
                # the queue, we still want the legitimate messages on the queue
                # so just tell it to ignore stops it got before now
                context.ignore_stops_before_now()
                func_name = self._hooked_func_names[hook].get(part, None)
                if func_name:
                    hook_runners[part] = part.make_hook_runner(
                        hook_queue, func_name, weakref.proxy(context), *args,
                        **params)

        return hook_queue, hook_runners

    def wait_hook(self, hook_queue, hook_runners):
        """Wait for every hook runner to report on hook_queue.

        Returns:
            dict: {part_name: return value}

        Raises:
            Exception: the first Exception a part returns; the remaining
                runners are stopped (unless it was an AbortedError) and
                waited for with ABORT_TIMEOUT first
        """
        # Wait for them all to finish
        return_dict = {}
        start = time.time()
        while hook_runners:
            part, ret = hook_queue.get()
            hook_runner = hook_runners.pop(part)

            # Wait for the process to terminate
            hook_runner.wait()
            return_dict[part.name] = ret
            duration = time.time() - start
            if hook_runners:
                self.log.debug(
                    "%s: Part %s returned %r after %ss. Still waiting for %s",
                    hook_queue.hook_name, part.name, ret, duration,
                    [p.name for p in hook_runners])
            else:
                self.log.debug(
                    "%s: Part %s returned %r after %ss. Returning...",
                    hook_queue.hook_name, part.name, ret, duration)

            if isinstance(ret, Exception):
                if not isinstance(ret, AbortedError):
                    # If AbortedError, all tasks have already been stopped.
                    # Got an error, so stop and wait all hook runners
                    for h in hook_runners.values():
                        h.stop()
                # Wait for them to finish
                for h in hook_runners.values():
                    h.wait(timeout=ABORT_TIMEOUT)
                raise ret

        return return_dict
コード例 #8
0
ファイル: process.py プロジェクト: shroffk/pymalcolm
class Process(Loggable):
    """Hosts a number of Blocks, distributing requests between them

    Requests put on self.q are serviced by recv_loop(), which dispatches on
    the request's type via self._handle_functions.
    """
    def __init__(self, name, sync_factory):
        """
        Args:
            name (str): Name of the Process, also used as the name of the
                Block created by create_process_block()
            sync_factory: Used to create queues and locks and to spawn
                functions
        """
        self.set_logger_name(name)
        self.name = name
        self.sync_factory = sync_factory
        self.q = self.create_queue()
        self._blocks = OrderedDict()  # block_name -> Block
        self._controllers = OrderedDict()  # block_name -> Controller
        self._block_state_cache = Cache()
        self._recv_spawned = None
        # [(spawned, function)] that had not completed when last checked
        self._other_spawned = []
        # lookup of all Subscribe requests, ordered to guarantee subscription
        # notification ordering
        # {Request.generate_key(): Subscribe}
        self._subscriptions = OrderedDict()
        self.comms = []
        self._client_comms = OrderedDict()  # client comms -> list of blocks
        self._handle_functions = {
            Post: self._forward_block_request,
            Put: self._forward_block_request,
            Get: self._handle_get,
            Subscribe: self._handle_subscribe,
            Unsubscribe: self._handle_unsubscribe,
            BlockChanges: self._handle_block_changes,
            BlockRespond: self._handle_block_respond,
            BlockAdd: self._handle_block_add,
            BlockList: self._handle_block_list,
            AddSpawned: self._add_spawned,
        }
        self.create_process_block()

    def recv_loop(self):
        """Service self.q, distributing the requests to the right block

        Runs until the PROCESS_STOP sentinel is received. Each request is
        passed to the handler registered for its type in _handle_functions;
        any exception is logged and, where the request supports it, sent
        back via request.respond_with_error().
        """
        while True:
            request = self.q.get()
            self.log_debug("Received request %s", request)
            if request is PROCESS_STOP:
                # Got the sentinel, stop immediately
                break
            try:
                self._handle_functions[type(request)](request)
            except Exception as e:  # pylint:disable=broad-except
                self.log_exception("Exception while handling %s", request)
                try:
                    request.respond_with_error(str(e))
                except Exception:
                    # Best effort: the original failure is already logged,
                    # and presumably not every request type can respond
                    pass

    def add_comms(self, comms):
        """Register a comms object to be started and stopped with this
        Process. Must be called before start()."""
        assert not self._recv_spawned, \
            "Can't add comms when process has been started"
        self.comms.append(comms)

    def start(self):
        """Start the process going"""
        self._recv_spawned = self.sync_factory.spawn(self.recv_loop)
        for comms in self.comms:
            comms.start()

    def stop(self, timeout=None):
        """Stop the process and wait for it to finish

        Args:
            timeout (float): Maximum amount of time to wait for each spawned
                process. None means forever
        """
        assert self._recv_spawned, "Process not started"
        self.q.put(PROCESS_STOP)
        for comms in self.comms:
            comms.stop()
        # Wait for recv_loop to complete first
        self._recv_spawned.wait(timeout=timeout)
        # Now wait for anything it spawned to complete
        for s, _ in self._other_spawned:
            s.wait(timeout=timeout)
        # Garbage collect the syncfactory
        del self.sync_factory

    def _forward_block_request(self, request):
        """Lookup target Block and spawn block.handle_request(request)

        Args:
            request (Request): The message that should be passed to the Block
        """
        block_name = request.endpoint[0]
        block = self._blocks[block_name]
        spawned = self.sync_factory.spawn(block.handle_request, request)
        self._add_spawned(AddSpawned(spawned, block.handle_request))

    def create_queue(self):
        """
        Create a queue using sync_factory object

        Returns:
            Queue: New queue
        """

        return self.sync_factory.create_queue()

    def create_lock(self):
        """
        Create a lock using sync_factory object

        Returns:
            New lock object
        """
        return self.sync_factory.create_lock()

    def spawn(self, function, *args, **kwargs):
        """Calls SyncFactory.spawn()

        Any exception raised by function(*args, **kwargs) is logged and then
        re-raised. The spawned object is registered via an AddSpawned
        request on self.q so that stop() can wait for it.

        Returns:
            The object returned by sync_factory.spawn
        """
        def catching_function():
            try:
                function(*args, **kwargs)
            except Exception:
                self.log_exception("Exception calling %s(*%s, **%s)", function,
                                   args, kwargs)
                raise

        spawned = self.sync_factory.spawn(catching_function)
        request = AddSpawned(spawned, function)
        self.q.put(request)
        return spawned

    def _add_spawned(self, request):
        """Record request.spawned, dropping any recorded spawned that are
        already ready() so the list does not grow without bound"""
        spawned = self._other_spawned
        self._other_spawned = []
        spawned.append((request.spawned, request.function))
        # Filter out the spawned that have completed to stop memory leaks
        for sp, f in spawned:
            if not sp.ready():
                self._other_spawned.append((sp, f))

    def get_client_comms(self, block_name):
        """Return the client comms that reported block_name, or None if no
        client comms has reported it"""
        for client_comms, blocks in list(self._client_comms.items()):
            if block_name in blocks:
                return client_comms

    def create_process_block(self):
        """Create the Block describing this Process, with "blocks" and
        "remoteBlocks" string-array attributes, and host it as self.name"""
        self.process_block = Block()
        # TODO: add a meta here
        children = OrderedDict()
        children["blocks"] = StringArrayMeta(
            description="Blocks hosted by this Process").make_attribute([])
        children["remoteBlocks"] = StringArrayMeta(
            description="Blocks reachable via ClientComms").make_attribute([])
        self.process_block.replace_endpoints(children)
        self.process_block.set_process_path(self, [self.name])
        self.add_block(self.process_block, self)

    def update_block_list(self, client_comms, blocks):
        """Queue a BlockList request so the recv loop records blocks as the
        list reachable via client_comms"""
        self.q.put(BlockList(client_comms=client_comms, blocks=blocks))

    def _handle_block_list(self, request):
        """Record the blocks reported by a client comms and republish the
        merged list of every client's blocks as remoteBlocks"""
        self._client_comms[request.client_comms] = request.blocks
        # OrderedDict.fromkeys keeps first-seen order and drops duplicates,
        # including duplicates within a single client comms' list
        remotes = list(OrderedDict.fromkeys(
            block for blocks in self._client_comms.values()
            for block in blocks))
        self.process_block["remoteBlocks"].set_value(remotes)

    def _handle_block_changes(self, request):
        """Update subscribers with changes and applies stored changes to the
        cached structure"""
        # update cached dict
        subscription_changes = self._block_state_cache.apply_changes(
            *request.changes)

        # Send out the changes
        for subscription, changes in subscription_changes.items():
            if subscription.delta:
                # respond with the filtered changes
                subscription.respond_with_delta(changes)
            else:
                # respond with the structure of everything
                # below the endpoint
                d = self._block_state_cache.walk_path(subscription.endpoint)
                subscription.respond_with_update(d)

    def report_changes(self, *changes):
        """Queue changes as a BlockChanges request for the recv loop"""
        self.q.put(BlockChanges(changes=list(changes)))

    def block_respond(self, response, response_queue):
        """Queue a BlockRespond request so the recv loop puts response on
        response_queue"""
        self.q.put(BlockRespond(response, response_queue))

    def _handle_block_respond(self, request):
        """Push the response to the required queue"""
        request.response_queue.put(request.response)

    def add_block(self, block, controller):
        """Add a block to be hosted by this process

        If the process has been started the addition is queued for the recv
        loop, otherwise it is done immediately in this thread.

        Args:
            block (Block): The block to be added
            controller (Controller): Its controller
        """
        path = block.process_path
        assert len(path) == 1, \
            "Expected block %r to have %r as parent, got path %r" % \
            (block, self, path)
        name = path[0]
        assert name not in self._blocks, \
            "There is already a block called %r" % name
        request = BlockAdd(block=block, controller=controller, name=name)
        if self._recv_spawned:
            # Started, so call in Process thread
            self.q.put(request)
        else:
            # Not started yet so we are safe to add in this thread
            self._handle_block_add(request)

    def _handle_block_add(self, request):
        """Add a block to be hosted by this process"""
        assert request.name not in self._blocks, \
            "There is already a block called %r" % request.name
        self._blocks[request.name] = request.block
        self._controllers[request.name] = request.controller
        serialized = request.block.to_dict()
        change_request = BlockChanges([[[request.name], serialized]])
        self._handle_block_changes(change_request)
        # Regenerate list of blocks
        self.process_block["blocks"].set_value(list(self._blocks))

    def get_block(self, block_name):
        """Return the local Block called block_name, or a client Block for it
        if it is listed in process_block.remoteBlocks.

        Raises:
            KeyError: if block_name is neither local nor listed as remote
        """
        try:
            return self._blocks[block_name]
        except KeyError:
            if block_name in self.process_block.remoteBlocks:
                return self.make_client_block(block_name)
            else:
                raise

    def make_client_block(self, block_name):
        """Create a ClientController for block_name and return its Block"""
        params = ClientController.MethodMeta.prepare_input_map(mri=block_name)
        controller = ClientController(self, {}, params)
        return controller.block

    def get_controller(self, block_name):
        """Return the Controller registered for block_name (KeyError if none)"""
        return self._controllers[block_name]

    def _handle_subscribe(self, request):
        """Add a new subscriber and respond with the current
        sub-structure state"""
        key = request.generate_key()
        assert key not in self._subscriptions, \
            "Subscription on %s already exists" % (key,)
        self._subscriptions[key] = request
        self._block_state_cache.add_subscriber(request, request.endpoint)
        d = self._block_state_cache.walk_path(request.endpoint)
        self.log_debug("Initial subscription value %s", d)
        if request.delta:
            request.respond_with_delta([[[], d]])
        else:
            request.respond_with_update(d)

    def _handle_unsubscribe(self, request):
        """Remove a subscriber and respond with success or error"""
        key = request.generate_key()
        try:
            subscription = self._subscriptions.pop(key)
        except KeyError:
            request.respond_with_error("No subscription found for %s" %
                                       (key, ))
        else:
            self._block_state_cache.remove_subscriber(subscription,
                                                      subscription.endpoint)
            request.respond_with_return()

    def _handle_get(self, request):
        """Respond to a Get with the cached structure at request.endpoint"""
        d = self._block_state_cache.walk_path(request.endpoint)
        request.respond_with_return(d)