Example #1
0
 def test_add_block_calls_handle(self):
     """Registering a block and cycling the process leaves exactly two
     blocks hosted: the added one and the process's own block."""
     sync = SyncFactory("sched")
     proc = Process("proc", sync)
     blk = Block()
     blk.set_parent(proc, "myblock")
     proc.add_block(blk)
     # Start and immediately stop so any queued work is drained
     proc.start()
     proc.stop()
     expected = dict(myblock=blk, proc=proc.process_block)
     self.assertEqual(len(proc._blocks), 2)
     self.assertEqual(proc._blocks, expected)
Example #2
0
class TestHandleRequest(unittest.TestCase):
    """Tests for Block.handle_request routing of Post/Put requests."""

    def setUp(self):
        self.block = Block()
        self.block.set_parent(MagicMock(), "TestBlock")
        self.method = MagicMock()
        self.attribute = MagicMock()
        self.response = MagicMock()
        self.block.add_method('get_things', self.method)
        self.block.add_attribute('test_attribute', self.attribute)

    def test_given_request_then_pass_to_correct_method(self):
        endpoint = ["TestBlock", "get_things"]
        request = Post(MagicMock(), MagicMock(), endpoint)

        self.block.handle_request(request)

        self.method.get_response.assert_called_once_with(request)
        response = self.method.get_response.return_value
        self.block.parent.block_respond.assert_called_once_with(
            response, request.response_queue)

    def test_given_put_then_update_attribute(self):
        endpoint = ["TestBlock", "test_attribute", "value"]
        value = "5"
        request = Put(MagicMock(), MagicMock(), endpoint, value)

        self.block.handle_request(request)

        self.attribute.put.assert_called_once_with(value)
        self.attribute.set_value.assert_called_once_with(value)
        response = self.block.parent.block_respond.call_args[0][0]
        self.assertEqual("malcolm:core/Return:1.0", response.typeid)
        self.assertIsNone(response.value)
        response_queue = self.block.parent.block_respond.call_args[0][1]
        self.assertEqual(request.response_queue, response_queue)

    # The two tests below previously shared the name
    # ``test_invalid_request_fails``, so the first was silently shadowed by
    # the second and never ran. They now have distinct names.
    def test_invalid_request_type_fails(self):
        """A request whose type_ is not handled is rejected."""
        request = MagicMock()
        request.type_ = "Get"

        self.assertRaises(AssertionError, self.block.handle_request, request)

    def test_invalid_endpoint_fails(self):
        """Post and Put requests with an over-long endpoint are rejected."""
        endpoint = ["a", "b", "c", "d"]
        request = Post(MagicMock(), MagicMock(), endpoint)
        self.assertRaises(ValueError, self.block.handle_request, request)

        request = Put(MagicMock(), MagicMock(), endpoint)
        self.assertRaises(ValueError, self.block.handle_request, request)
Example #3
0
class TestHandleRequest(unittest.TestCase):
    """Tests for Block.handle_request routing of Post/Put requests."""

    def setUp(self):
        self.block = Block()
        self.block.set_parent(MagicMock(), "TestBlock")
        self.method = MagicMock()
        self.attribute = MagicMock()
        self.response = MagicMock()
        self.block.add_method('get_things', self.method)
        self.block.add_attribute('test_attribute', self.attribute)

    def test_given_request_then_pass_to_correct_method(self):
        endpoint = ["TestBlock", "get_things"]
        request = Post(MagicMock(), MagicMock(), endpoint)

        self.block.handle_request(request)

        self.method.get_response.assert_called_once_with(request)
        response = self.method.get_response.return_value
        self.block.parent.block_respond.assert_called_once_with(
            response, request.response_queue)

    def test_given_put_then_update_attribute(self):
        endpoint = ["TestBlock", "test_attribute", "value"]
        value = "5"
        request = Put(MagicMock(), MagicMock(), endpoint, value)

        self.block.handle_request(request)

        self.attribute.put.assert_called_once_with(value)
        self.attribute.set_value.assert_called_once_with(value)
        response = self.block.parent.block_respond.call_args[0][0]
        self.assertEqual("malcolm:core/Return:1.0", response.typeid)
        self.assertIsNone(response.value)
        response_queue = self.block.parent.block_respond.call_args[0][1]
        self.assertEqual(request.response_queue, response_queue)

    # The two tests below previously shared the name
    # ``test_invalid_request_fails``, so the first was silently shadowed by
    # the second and never ran. They now have distinct names.
    def test_invalid_request_type_fails(self):
        """A request whose type_ is not handled is rejected."""
        request = MagicMock()
        request.type_ = "Get"

        self.assertRaises(AssertionError, self.block.handle_request, request)

    def test_invalid_endpoint_fails(self):
        """Post and Put requests with an over-long endpoint are rejected."""
        endpoint = ["a", "b", "c", "d"]
        request = Post(MagicMock(), MagicMock(), endpoint)
        self.assertRaises(ValueError, self.block.handle_request, request)

        request = Put(MagicMock(), MagicMock(), endpoint)
        self.assertRaises(ValueError, self.block.handle_request, request)
Example #4
0
    def test_returns_dict(self):
        """Block.to_dict() yields the typeid followed by each child's own
        to_dict() output, and asks each method for it exactly once."""
        meth_dict = OrderedDict(takes=OrderedDict(one=OrderedDict()),
                                returns=OrderedDict(one=OrderedDict()),
                                defaults=OrderedDict())

        meth_one = MagicMock()
        meth_one.to_dict.return_value = meth_dict
        meth_two = MagicMock()
        meth_two.to_dict.return_value = meth_dict

        attr_one = MagicMock()
        attr_one_dict = OrderedDict(value="test", meta=MagicMock())
        attr_one.to_dict.return_value = attr_one_dict
        attr_two = MagicMock()
        attr_two_dict = OrderedDict(value="value", meta=MagicMock())
        attr_two.to_dict.return_value = attr_two_dict

        blk = Block()
        blk.set_parent(MagicMock(), "Test")
        blk.add_method('method_one', meth_one)
        blk.add_method('method_two', meth_two)
        blk.add_attribute('attr_one', attr_one)
        blk.add_attribute('attr_two', attr_two)

        # Forget any calls made while the children were being added
        for child in (meth_one, meth_two, attr_one, attr_two):
            child.reset_mock()

        expected = OrderedDict([
            ('typeid', "malcolm:core/Block:1.0"),
            ('attr_one', attr_one_dict),
            ('attr_two', attr_two_dict),
            ('method_one', meth_dict),
            ('method_two', meth_dict),
        ])

        actual = blk.to_dict()

        meth_one.to_dict.assert_called_once_with()
        meth_two.to_dict.assert_called_once_with()
        self.assertEqual(expected, actual)
Example #5
0
    def test_returns_dict(self):
        """Serialising a Block gives its typeid and every child's to_dict()
        result, with each method queried once."""
        shared_method_dict = OrderedDict(
            takes=OrderedDict(one=OrderedDict()),
            returns=OrderedDict(one=OrderedDict()),
            defaults=OrderedDict())

        mock_methods = [MagicMock(), MagicMock()]
        for mock_method in mock_methods:
            mock_method.to_dict.return_value = shared_method_dict
        first_method, second_method = mock_methods

        first_attr_dict = OrderedDict(value="test", meta=MagicMock())
        second_attr_dict = OrderedDict(value="value", meta=MagicMock())
        first_attr, second_attr = MagicMock(), MagicMock()
        first_attr.to_dict.return_value = first_attr_dict
        second_attr.to_dict.return_value = second_attr_dict

        block = Block()
        block.set_parent(MagicMock(), "Test")
        block.add_method('method_one', first_method)
        block.add_method('method_two', second_method)
        block.add_attribute('attr_one', first_attr)
        block.add_attribute('attr_two', second_attr)

        # Clear call history accumulated during setup
        for mocked in (first_method, second_method, first_attr, second_attr):
            mocked.reset_mock()

        expected_dict = OrderedDict()
        expected_dict['typeid'] = "malcolm:core/Block:1.0"
        expected_dict.update([('attr_one', first_attr_dict),
                              ('attr_two', second_attr_dict),
                              ('method_one', shared_method_dict),
                              ('method_two', shared_method_dict)])

        result = block.to_dict()

        first_method.to_dict.assert_called_once_with()
        second_method.to_dict.assert_called_once_with()
        self.assertEqual(expected_dict, result)
class Controller(Loggable):
    """Implement the logic that takes a Block through its state machine.

    The controller creates a Block, fills it with attributes and methods
    (its own plus those of its Parts), and maintains the Block's ``state``,
    ``status`` and ``busy`` attributes as it transitions between states.
    ``self.stateMachine`` is used throughout but not defined here —
    presumably supplied by a subclass (TODO confirm).
    """

    # Attributes for all controllers, created in _create_default_attributes
    state = None
    status = None
    busy = None
    # BlockMeta for descriptions, created in create_meta
    meta = None

    def __init__(self, block_name, process, parts=None, params=None):
        """
        Args:
            block_name (str): Name the Block is registered under in process
            process (Process): The process this should run under
            parts (dict): Optional {part_name: part}; None means no parts
            params: Optional parameters, stored unchanged on self.params
        """
        controller_name = "%s(%s)" % (type(self).__name__, block_name)
        self.set_logger_name(controller_name)
        self.block = Block()
        self.log_debug("Creating block %r as %r", self.block, block_name)
        self.block_name = block_name
        self.params = params
        self.process = process
        self.lock = process.create_lock()
        # {part: task}
        self.part_tasks = {}
        # dictionary of dictionaries
        # {state (str): {MethodMeta: writeable (bool)}
        self.methods_writeable = {}
        # dict {hook: name}
        self.hook_names = self._find_hooks()
        self.parts = self._setup_parts(parts, controller_name)
        # Populate the block and put it in DISABLED before handing it to
        # the process
        self._set_block_children()
        self._do_transition(sm.DISABLED, "Disabled")
        self.block.set_parent(process, block_name)
        process.add_block(self.block)
        self.do_initial_reset()

    def _find_hooks(self):
        """Return {Hook: attribute name} for every Hook on this controller.

        Raises:
            AssertionError: if the same Hook object is found under two names
        """
        hook_names = {}
        for n in dir(self):
            attr = getattr(self, n)
            if isinstance(attr, Hook):
                assert attr not in hook_names, \
                    "Hook %s already in controller as %s" % (
                        n, hook_names[attr])
                hook_names[attr] = n
        return hook_names

    def _setup_parts(self, parts, controller_name):
        """Name each part's logger and check its hooked functions.

        Args:
            parts (dict): {part_name: part}, or None for no parts
            controller_name (str): Prefix for each part's logger name

        Returns:
            dict: parts, or a new empty dict if parts was None

        Raises:
            AssertionError: if a part hooks into a Hook this controller
                does not have
        """
        if parts is None:
            parts = {}
        for part_name, part in parts.items():
            part.set_logger_name("%s.%s" % (controller_name, part_name))
            # Check part hooks into one of our hooks
            for func_name, part_hook, _ in get_hook_decorated(part):
                assert part_hook in self.hook_names, \
                    "Part %s func %s not hooked into %s" % (
                        part, func_name, self)
        return parts

    def do_initial_reset(self):
        """Queue a Post to this block's "reset" method on the process queue."""
        request = Post(None, self.process.create_queue(),
                       [self.block_name, "reset"])
        self.process.q.put(request)

    def add_change(self, changes, item, attr, value):
        """Append [path, value] to changes, where path is item's path
        relative to self.block followed by attr."""
        path = item.path_relative_to(self.block) + [attr]
        changes.append([path, value])

    def _set_block_children(self):
        """Gather children from this controller and its parts and install
        them on self.block.

        Each child comes as a (name, child, writeable_func) tuple. MethodMeta
        children are copied and registered as writeable in their only_in
        states, or in every state except DISABLING/DISABLED when only_in is
        None. Attribute children are writeable iff writeable_func is given.
        """
        # reconfigure block with new children
        child_list = [self.create_meta()]
        child_list += list(self._create_default_attributes())
        child_list += list(self.create_attributes())
        child_list += list(self.create_methods())
        for part in self.parts.values():
            child_list += list(part.create_attributes())
            child_list += list(part.create_methods())

        self.methods_writeable = {}
        writeable_functions = {}
        children = OrderedDict()

        for name, child, writeable_func in child_list:
            if isinstance(child, MethodMeta):
                # Set if the method is writeable
                if child.only_in is None:
                    states = [
                        state for state in self.stateMachine.possible_states
                        if state not in (sm.DISABLING, sm.DISABLED)
                    ]
                else:
                    states = child.only_in
                    for state in states:
                        assert state in self.stateMachine.possible_states, \
                            "State %s is not one of the valid states %s" % \
                            (state, self.stateMachine.possible_states)
                # Make a copy otherwise all instances will own the same one
                child = MethodMeta.from_dict(child.to_dict())
                self.register_method_writeable(child, states)
            elif isinstance(child, Attribute):
                child.meta.set_writeable(writeable_func is not None)
            children[name] = child
            if writeable_func:
                writeable_functions[name] = functools.partial(
                    self.call_writeable_function, writeable_func)

        self.block.replace_endpoints(children)
        self.block.set_writeable_functions(writeable_functions)

    def call_writeable_function(self, function, child, *args):
        """Call function(*args) if child is currently writeable.

        Only the writeable check is made under self.lock; function itself
        runs after the lock has been released.

        Raises:
            ValueError: if child is not writeable
        """
        with self.lock:
            if not child.writeable:
                child.log_error("I'm not writeable")
                raise ValueError("Child %r is not writeable" % (child, ))
        result = function(*args)
        return result

    def _create_default_attributes(self):
        """Create the state, status and busy attributes.

        Yields:
            tuple: (name, Attribute, None); none of them has a put function
        """
        # Add the state, status and busy attributes
        self.state = ChoiceMeta("State of Block",
                                self.stateMachine.possible_states,
                                label="State").make_attribute()
        yield "state", self.state, None
        self.status = StringMeta("Status of Block",
                                 label="Status").make_attribute()
        yield "status", self.status, None
        self.busy = BooleanMeta("Whether Block busy or not",
                                label="Busy").make_attribute()
        yield "busy", self.busy, None

    def create_meta(self):
        """Create self.meta and return it as a ("meta", BlockMeta, None)
        child tuple."""
        self.meta = BlockMeta()
        return "meta", self.meta, None

    def create_attributes(self):
        """Method that should provide Attribute instances for Block

        Yields:
            tuple: (string name, Attribute, callable put_function).
        """
        return iter(())

    def create_methods(self):
        """Method that should provide MethodMeta instances for Block

        Yields:
            tuple: (string name, MethodMeta, callable post_function).
        """
        return get_method_decorated(self)

    def transition(self, state, message, create_tasks=False):
        """
        Change to a new state if the transition is allowed

        Args:
            state(str): State to transition to
            message(str): Status message
            create_tasks(bool): If true then make self.part_tasks

        Raises:
            TypeError: if the state machine does not allow the transition
        """
        with self.lock:
            if self.stateMachine.is_allowed(initial_state=self.state.value,
                                            target_state=state):
                self._do_transition(state, message)
                if create_tasks:
                    self.part_tasks = self.create_part_tasks()
            else:
                raise TypeError("Cannot transition from %s to %s" %
                                (self.state.value, state))

    def _do_transition(self, state, message):
        """Apply state/status/busy and per-method writeable changes for state
        to self.block in a single apply_changes call (no permission check)."""
        # transition is allowed, so set attributes
        changes = []
        self.add_change(changes, self.state, "value", state)
        self.add_change(changes, self.status, "value", message)
        self.add_change(changes, self.busy, "value",
                        state in self.stateMachine.busy_states)

        # say which methods can now be called
        for name in self.block:
            child = self.block[name]
            if isinstance(child, MethodMeta):
                method = child
                writeable = self.methods_writeable[state][method]
                self.add_change(changes, method, "writeable", writeable)
                # Arguments of the method follow its writeability
                for ename in method.takes.elements:
                    meta = method.takes.elements[ename]
                    self.add_change(changes, meta, "writeable", writeable)

        self.log_debug("Transitioning to %s", state)
        self.block.apply_changes(*changes)

    def register_method_writeable(self, method, states):
        """
        Set the states that the given method can be called in

        Args:
            method(MethodMeta): Method that will be set writeable or not
            states(list[str]): List of states where method is writeable
        """
        for state in self.stateMachine.possible_states:
            writeable_dict = self.methods_writeable.setdefault(state, {})
            is_writeable = state in states
            writeable_dict[method] = is_writeable

    def create_part_tasks(self):
        """Return {part: Task} with one new Task per part."""
        part_tasks = {}
        for part_name, part in self.parts.items():
            part_tasks[part] = Task("Task(%s)" % part_name, self.process)
        return part_tasks

    def run_hook(self, hook, part_tasks, **kwargs):
        """Run hook on part_tasks, wait for every hooked function, and
        return the hook's return table.

        Rows are added, in self.parts order, for each part whose returned
        value is truthy. Exceptions raised by hooked functions propagate
        from wait_hook.
        """
        hook_queue, func_tasks, task_part_names = self.start_hook(
            hook, part_tasks, **kwargs)
        return_table = hook.make_return_table(part_tasks)
        return_dict = self.wait_hook(hook_queue, func_tasks, task_part_names)
        for part_name in self.parts:
            return_map = return_dict.get(part_name, None)
            if return_map:
                self.fill_in_table(part_name, return_table, return_map)
        return return_table

    def fill_in_table(self, part_name, table, return_map):
        """Append rows for part_name to table built from return_map.

        Columns tagged "hook:return_array" are read element-wise (one row
        per element; all such arrays must share a length). Other columns
        repeat the same return_map value on every row, and the "name" column
        is set to part_name.

        Raises:
            AssertionError: if the array columns have differing lengths
        """
        # Find all the array columns
        arrays = {}
        for column_name in table.meta.elements:
            meta = table.meta.elements[column_name]
            if "hook:return_array" in meta.tags:
                arrays[column_name] = return_map[column_name]
        # If there are any arrays, make sure they are the right length
        lengths = set(len(arr) for arr in arrays.values())
        if len(lengths) == 0:
            # no arrays
            iterations = 1
        else:
            assert len(lengths) == 1, \
                "Varying array length %s for rows %s" % (lengths, arrays)
            iterations = lengths.pop()
        for i in range(iterations):
            row = []
            for k in table.endpoints:
                if k == "name":
                    row.append(part_name)
                elif k in arrays:
                    row.append(arrays[k][i])
                else:
                    row.append(return_map[k])
            table.append(row)

    def make_task_return_value_function(self, hook_queue, **kwargs):
        """Return a task_return(func, task) callable.

        The callable runs func with kwargs and puts (func, result) on
        hook_queue. If func raises, the exception object is put instead
        of a result.
        """
        def task_return(func, task):
            try:
                result = func.MethodMeta.call_post_function(func, kwargs, task)
            except StopIteration as e:
                self.log_debug("%s has been aborted", func)
                result = e
            except Exception as e:  # pylint:disable=broad-except
                self.log_exception("%s %s raised exception", func, kwargs)
                result = e
            self.log_debug("Putting %r on queue", result)
            hook_queue.put((func, result))

        return task_return

    def start_hook(self, hook, part_tasks, **kwargs):
        """Start a task for every function hooked into hook.

        Returns:
            tuple: (hook_queue, func_tasks, task_part_names) where
            func_tasks is {func: task} and task_part_names is
            {task: part_name}; pass these to wait_hook.

        Raises:
            AssertionError: if hook is not one of this controller's hooks
        """
        assert hook in self.hook_names, \
            "Hook %s doesn't appear in controller hooks %s" % (
                hook, self.hook_names)
        self.log_debug("Running %s hook", self.hook_names[hook])

        # ask the hook to find the functions it should run
        func_tasks = hook.find_func_tasks(part_tasks)

        # now start them off
        hook_queue = self.process.create_queue()
        task_return = self.make_task_return_value_function(
            hook_queue, **kwargs)

        for func, task in func_tasks.items():
            task.define_spawn_function(task_return, func, task)
            self.log_debug("Starting task %r", task)
            task.start()

        # Create the reverse dictionary so we know where to store the results
        task_part_names = {}
        for part_name, part in self.parts.items():
            if part in part_tasks:
                task_part_names[part_tasks[part]] = part_name

        return hook_queue, func_tasks, task_part_names

    def wait_hook(self, hook_queue, func_tasks, task_part_names):
        """Wait until every task in func_tasks has reported on hook_queue.

        func_tasks is consumed: entries are popped as results arrive.

        Returns:
            dict: {part_name: result}

        Raises:
            Exception: the first exception a hooked function returned. The
                remaining tasks are stopped and waited for first. A
                StopIteration (abort) is re-raised too, but its own task is
                not waited for.
        """
        # Wait for them all to finish
        return_dict = {}
        while func_tasks:
            func, ret = hook_queue.get()
            task = func_tasks.pop(func)
            part_name = task_part_names[task]
            return_dict[part_name] = ret

            if isinstance(ret, Exception):
                # Stop all other tasks. A separate loop variable is used so
                # that ``task`` still refers to the task that just reported.
                for other_task in func_tasks.values():
                    other_task.stop()
                for other_task in func_tasks.values():
                    other_task.wait()

            # If we got a StopIteration, someone asked us to stop, so
            # don't wait, otherwise make sure we finished
            if not isinstance(ret, StopIteration):
                task.wait()

            if isinstance(ret, Exception):
                raise ret

        return return_dict
class Process(Loggable):
    """Hosts a number of Blocks, distributing requests between them"""

    def __init__(self, name, sync_factory):
        """
        Args:
            name (str): Name of the process; also its logger name and the
                name of its own process block
            sync_factory (SyncFactory): Used to create queues and locks and
                to spawn functions
        """
        self.set_logger_name(name)
        self.name = name
        self.sync_factory = sync_factory
        self.q = self.create_queue()
        self._blocks = OrderedDict()  # block name -> block
        self._block_state_cache = Cache()
        self._recv_spawned = None
        self._other_spawned = []
        self._subscriptions = []
        self.comms = []
        self._client_comms = OrderedDict()  # client comms -> list of blocks
        self._handle_functions = {
            Post: self._forward_block_request,
            Put: self._forward_block_request,
            Get: self._handle_get,
            Subscribe: self._handle_subscribe,
            Unsubscribe: self._handle_unsubscribe,
            BlockChanges: self._handle_block_changes,
            BlockRespond: self._handle_block_respond,
            BlockAdd: self._handle_block_add,
            BlockList: self._handle_block_list,
        }
        self.create_process_block()

    def recv_loop(self):
        """Service self.q, distributing the requests to the right block.

        Runs until PROCESS_STOP is received. Exceptions from a handler are
        logged and reported back through request.respond_with_error.
        """
        while True:
            request = self.q.get()
            self.log_debug("Received request %s", request)
            if request is PROCESS_STOP:
                # Got the sentinel, stop immediately
                break
            try:
                self._handle_functions[type(request)](request)
            except Exception as e:  # pylint:disable=broad-except
                self.log_exception("Exception while handling %s", request)
                request.respond_with_error(str(e))

    def add_comms(self, comms):
        """Register a comms object to be started/stopped with the process.

        Raises:
            AssertionError: if the process has already been started
        """
        assert not self._recv_spawned, \
            "Can't add comms when process has been started"
        self.comms.append(comms)

    def start(self):
        """Start the process going"""
        self._recv_spawned = self.sync_factory.spawn(self.recv_loop)
        for comms in self.comms:
            comms.start()

    def stop(self, timeout=None):
        """Stop the process and wait for it to finish

        Args:
            timeout (float): Maximum amount of time to wait for each spawned
            process. None means forever
        """
        assert self._recv_spawned, "Process not started"
        self.q.put(PROCESS_STOP)
        for comms in self.comms:
            comms.stop()
        # Wait for recv_loop to complete first
        self._recv_spawned.wait(timeout=timeout)
        # Now wait for anything it spawned to complete
        for s in self._other_spawned:
            s.wait(timeout=timeout)

    def _forward_block_request(self, request):
        """Lookup target Block and spawn block.handle_request(request)

        Args:
            request (Request): The message that should be passed to the Block
        """
        block_name = request.endpoint[0]
        block = self._blocks[block_name]
        self._other_spawned.append(
            self.sync_factory.spawn(block.handle_request, request))

    def create_queue(self):
        """
        Create a queue using sync_factory object

        Returns:
            Queue: New queue
        """

        return self.sync_factory.create_queue()

    def create_lock(self):
        """
        Create a lock using sync_factory object

        Returns:
            Lock: New lock
        """
        return self.sync_factory.create_lock()

    def spawn(self, function, *args, **kwargs):
        """Calls SyncFactory.spawn() and tracks the result so stop() waits
        for it"""
        spawned = self.sync_factory.spawn(function, *args, **kwargs)
        self._other_spawned.append(spawned)
        return spawned

    def get_client_comms(self, block_name):
        """Return the client comms that reported block_name, or None."""
        for client_comms, blocks in list(self._client_comms.items()):
            if block_name in blocks:
                return client_comms

    def create_process_block(self):
        """Create this process's own Block (with "blocks" and
        "remoteBlocks" attributes) and add it to the process."""
        self.process_block = Block()
        # TODO: add a meta here
        children = OrderedDict()
        children["blocks"] = StringArrayMeta(
            description="Blocks hosted by this Process"
        ).make_attribute([])
        children["remoteBlocks"] = StringArrayMeta(
            description="Blocks reachable via ClientComms"
        ).make_attribute([])
        self.process_block.replace_endpoints(children)
        self.process_block.set_parent(self, self.name)
        self.add_block(self.process_block)

    def update_block_list(self, client_comms, blocks):
        """Queue a BlockList request recording blocks for client_comms."""
        self.q.put(BlockList(client_comms=client_comms, blocks=blocks))

    def _handle_block_list(self, request):
        """Store the blocks reachable via request.client_comms and set
        remoteBlocks to the union over all client comms, in first-seen
        order and without duplicates."""
        self._client_comms[request.client_comms] = request.blocks
        remotes = []
        seen = set()
        for blocks in self._client_comms.values():
            for block_name in blocks:
                # Dedup across AND within each client's list
                if block_name not in seen:
                    seen.add(block_name)
                    remotes.append(block_name)
        self.process_block["remoteBlocks"].set_value(remotes)

    def _handle_block_changes(self, request):
        """Update subscribers with changes and applies stored changes to the
        cached structure"""
        # update cached dict
        self._block_state_cache.apply_changes(*request.changes)

        for subscription in self._subscriptions:
            endpoint = subscription.endpoint
            # find stuff that's changed that is relevant to this subscriber
            changes = []
            for change in request.changes:
                change_path = change[0]
                # look for a change_path where the beginning matches the
                # endpoint path, then strip away the matching part and add
                # to the change set
                # NOTE(review): zip stops at the shorter sequence, so a
                # change at an ancestor of the endpoint (change_path shorter
                # than endpoint) is also kept, with path [] and the ancestor's
                # whole value. Delta subscribers may then receive more than
                # their sub-structure — confirm whether this is intended.
                i = 0
                for (cp_element, ep_element) in zip(change_path, endpoint):
                    if cp_element != ep_element:
                        break
                    i += 1
                else:
                    # change has matching path, so keep it
                    # but strip off the end point path
                    filtered_change = [change_path[i:]] + change[1:]
                    changes.append(filtered_change)
            if changes:
                if subscription.delta:
                    # respond with the filtered changes
                    subscription.respond_with_delta(changes)
                else:
                    # respond with the structure of everything
                    # below the endpoint
                    d = self._block_state_cache.walk_path(endpoint)
                    subscription.respond_with_update(d)

    def report_changes(self, *changes):
        """Queue a BlockChanges request carrying changes."""
        self.q.put(BlockChanges(changes=list(changes)))

    def block_respond(self, response, response_queue):
        """Queue a BlockRespond so response is delivered from the process
        thread."""
        self.q.put(BlockRespond(response, response_queue))

    def _handle_block_respond(self, request):
        """Push the response to the required queue"""
        request.response_queue.put(request.response)

    def add_block(self, block):
        """Add a block to be hosted by this process

        Args:
            block (Block): The block to be added

        Raises:
            AssertionError: if block's parent is not this process, or a
                block with the same name is already hosted
        """
        path = block.path_relative_to(self)
        assert len(path) == 1, \
            "Expected block %r to have %r as parent, got path %r" % \
            (block, self, path)
        name = path[0]
        assert name not in self._blocks, \
            "There is already a block called %r" % name
        request = BlockAdd(block=block, name=name)
        if self._recv_spawned:
            # Started, so call in Process thread
            self.q.put(request)
        else:
            # Not started yet so we are safe to add in this thread
            self._handle_block_add(request)

    def _handle_block_add(self, request):
        """Add a block to be hosted by this process"""
        block = request.block
        assert request.name not in self._blocks, \
            "There is already a block called %r" % request.name
        self._blocks[request.name] = block
        change_request = BlockChanges([[[request.name], block.to_dict()]])
        self._handle_block_changes(change_request)
        # Regenerate list of blocks
        self.process_block["blocks"].set_value(list(self._blocks))

    def get_block(self, block_name):
        """Return the hosted block called block_name, otherwise create a
        ClientController for it and return that controller's block."""
        try:
            return self._blocks[block_name]
        except KeyError:
            controller = ClientController(block_name, self)
            return controller.block

    def _handle_subscribe(self, request):
        """Add a new subscriber and respond with the current
        sub-structure state"""
        self._subscriptions.append(request)
        d = self._block_state_cache.walk_path(request.endpoint)
        self.log_debug("Initial subscription value %s", d)
        if request.delta:
            request.respond_with_delta([[[], d]])
        else:
            request.respond_with_update(d)

    def _handle_unsubscribe(self, request):
        """Remove a subscriber and respond with success or error"""
        subs = [s for s in self._subscriptions if s.id == request.id]
        # TODO: currently this will remove all subscriptions with a matching id
        #       there should only be one, we may want to warn if we see several
        #       Also, this should only filter by the queue/context, not sure
        #       which yet...
        if len(subs) == 0:
            request.respond_with_error(
                "No subscription found for id %d" % request.id)
        else:
            self._subscriptions = \
                [s for s in self._subscriptions if s.id != request.id]
            request.respond_with_return()

    def _handle_get(self, request):
        """Respond with the cached structure at request.endpoint."""
        d = self._block_state_cache.walk_path(request.endpoint)
        request.respond_with_return(d)
Example #8
0
 def test_notify(self):
     """A block forwards notify_subscribers to its parent with its name."""
     blk = Block()
     blk.set_parent(MagicMock(), "n")
     blk.notify_subscribers()
     parent_notify = blk.parent.notify_subscribers
     parent_notify.assert_called_once_with("n")
Example #9
0
 def test_add_block(self):
     """A block parented to the process is stored under its own name."""
     proc = Process("proc", MagicMock())
     blk = Block()
     blk.set_parent(proc, "name")
     proc.add_block(blk)
     registered = proc._blocks["name"]
     self.assertEqual(registered, blk)
Example #10
0
 def test_notify(self):
     """notify_subscribers on a block is relayed to the parent, tagged
     with the block's name."""
     block_under_test = Block()
     block_under_test.set_parent(MagicMock(), "n")
     block_under_test.notify_subscribers()
     (block_under_test.parent.notify_subscribers
      .assert_called_once_with("n"))