    def test_add_block_calls_handle(self):
        s = SyncFactory("sched")
        p = Process("proc", s)
        b = Block()
        b.set_parent(p, "myblock")
        p.add_block(b)
        p.start()
        p.stop()
        self.assertEqual(len(p._blocks), 2)
        self.assertEqual(p._blocks, dict(myblock=b, proc=p.process_block))
class TestHandleRequest(unittest.TestCase):

    def setUp(self):
        self.block = Block()
        self.block.set_parent(MagicMock(), "TestBlock")
        self.method = MagicMock()
        self.attribute = MagicMock()
        self.response = MagicMock()
        self.block.add_method('get_things', self.method)
        self.block.add_attribute('test_attribute', self.attribute)

    def test_given_request_then_pass_to_correct_method(self):
        endpoint = ["TestBlock", "get_things"]
        request = Post(MagicMock(), MagicMock(), endpoint)

        self.block.handle_request(request)

        self.method.get_response.assert_called_once_with(request)
        response = self.method.get_response.return_value
        self.block.parent.block_respond.assert_called_once_with(
            response, request.response_queue)

    def test_given_put_then_update_attribute(self):
        endpoint = ["TestBlock", "test_attribute", "value"]
        value = "5"
        request = Put(MagicMock(), MagicMock(), endpoint, value)

        self.block.handle_request(request)

        self.attribute.put.assert_called_once_with(value)
        self.attribute.set_value.assert_called_once_with(value)
        response = self.block.parent.block_respond.call_args[0][0]
        self.assertEqual("malcolm:core/Return:1.0", response.typeid)
        self.assertIsNone(response.value)
        response_queue = self.block.parent.block_respond.call_args[0][1]
        self.assertEqual(request.response_queue, response_queue)

    def test_invalid_request_fails(self):
        request = MagicMock()
        request.type_ = "Get"
        self.assertRaises(AssertionError, self.block.handle_request, request)

    def test_invalid_endpoint_length_fails(self):
        endpoint = ["a", "b", "c", "d"]
        request = Post(MagicMock(), MagicMock(), endpoint)
        self.assertRaises(ValueError, self.block.handle_request, request)
        request = Put(MagicMock(), MagicMock(), endpoint)
        self.assertRaises(ValueError, self.block.handle_request, request)
    def test_returns_dict(self):
        method_dict = OrderedDict(takes=OrderedDict(one=OrderedDict()),
                                  returns=OrderedDict(one=OrderedDict()),
                                  defaults=OrderedDict())
        m1 = MagicMock()
        m1.to_dict.return_value = method_dict
        m2 = MagicMock()
        m2.to_dict.return_value = method_dict

        a1 = MagicMock()
        a1dict = OrderedDict(value="test", meta=MagicMock())
        a1.to_dict.return_value = a1dict
        a2 = MagicMock()
        a2dict = OrderedDict(value="value", meta=MagicMock())
        a2.to_dict.return_value = a2dict

        block = Block()
        block.set_parent(MagicMock(), "Test")
        block.add_method('method_one', m1)
        block.add_method('method_two', m2)
        block.add_attribute('attr_one', a1)
        block.add_attribute('attr_two', a2)

        m1.reset_mock()
        m2.reset_mock()
        a1.reset_mock()
        a2.reset_mock()

        expected_dict = OrderedDict()
        expected_dict['typeid'] = "malcolm:core/Block:1.0"
        expected_dict['attr_one'] = a1dict
        expected_dict['attr_two'] = a2dict
        expected_dict['method_one'] = method_dict
        expected_dict['method_two'] = method_dict

        response = block.to_dict()

        m1.to_dict.assert_called_once_with()
        m2.to_dict.assert_called_once_with()
        self.assertEqual(expected_dict, response)
class Controller(Loggable):
    """Implement the logic that takes a Block through its state machine"""

    # Attributes for all controllers
    state = None
    status = None
    busy = None
    # BlockMeta for descriptions
    meta = None

    def __init__(self, block_name, process, parts=None, params=None):
        """
        Args:
            process (Process): The process this should run under
        """
        controller_name = "%s(%s)" % (type(self).__name__, block_name)
        self.set_logger_name(controller_name)
        self.block = Block()
        self.log_debug("Creating block %r as %r" % (self.block, block_name))
        self.block_name = block_name
        self.params = params
        self.process = process
        self.lock = process.create_lock()
        # {part: task}
        self.part_tasks = {}
        # dictionary of dictionaries
        # {state (str): {MethodMeta: writeable (bool)}
        self.methods_writeable = {}
        # dict {hook: name}
        self.hook_names = self._find_hooks()
        self.parts = self._setup_parts(parts, controller_name)
        self._set_block_children()
        self._do_transition(sm.DISABLED, "Disabled")
        self.block.set_parent(process, block_name)
        process.add_block(self.block)
        self.do_initial_reset()

    def _find_hooks(self):
        hook_names = {}
        for n in dir(self):
            attr = getattr(self, n)
            if isinstance(attr, Hook):
                assert attr not in hook_names, \
                    "Hook %s already in controller as %s" % (
                        n, hook_names[attr])
                hook_names[attr] = n
        return hook_names

    def _setup_parts(self, parts, controller_name):
        if parts is None:
            parts = {}
        for part_name, part in parts.items():
            part.set_logger_name("%s.%s" % (controller_name, part_name))
            # Check part hooks into one of our hooks
            for func_name, part_hook, _ in get_hook_decorated(part):
                assert part_hook in self.hook_names, \
                    "Part %s func %s not hooked into %s" % (
                        part, func_name, self)
        return parts

    def do_initial_reset(self):
        request = Post(None, self.process.create_queue(),
                       [self.block_name, "reset"])
        self.process.q.put(request)

    def add_change(self, changes, item, attr, value):
        path = item.path_relative_to(self.block) + [attr]
        changes.append([path, value])

    def _set_block_children(self):
        # reconfigure block with new children
        child_list = [self.create_meta()]
        child_list += list(self._create_default_attributes())
        child_list += list(self.create_attributes())
        child_list += list(self.create_methods())
        for part in self.parts.values():
            child_list += list(part.create_attributes())
            child_list += list(part.create_methods())

        self.methods_writeable = {}
        writeable_functions = {}
        children = OrderedDict()

        for name, child, writeable_func in child_list:
            if isinstance(child, MethodMeta):
                # Set if the method is writeable
                if child.only_in is None:
                    states = [
                        state for state in self.stateMachine.possible_states
                        if state not in (sm.DISABLING, sm.DISABLED)]
                else:
                    states = child.only_in
                    for state in states:
                        assert state in self.stateMachine.possible_states, \
                            "State %s is not one of the valid states %s" % \
                            (state, self.stateMachine.possible_states)
                # Make a copy otherwise all instances will own the same one
                child = MethodMeta.from_dict(child.to_dict())
                self.register_method_writeable(child, states)
            elif isinstance(child, Attribute):
                child.meta.set_writeable(writeable_func is not None)
            children[name] = child
            if writeable_func:
                writeable_functions[name] = functools.partial(
                    self.call_writeable_function, writeable_func)

        self.block.replace_endpoints(children)
        self.block.set_writeable_functions(writeable_functions)

    def call_writeable_function(self, function, child, *args):
        with self.lock:
            if not child.writeable:
                child.log_error("I'm not writeable")
                raise ValueError("Child %r is not writeable" % (child,))
            result = function(*args)
        return result
    def _create_default_attributes(self):
        # Add the state, status and busy attributes
        self.state = ChoiceMeta(
            "State of Block", self.stateMachine.possible_states,
            label="State").make_attribute()
        yield "state", self.state, None
        self.status = StringMeta(
            "Status of Block", label="Status").make_attribute()
        yield "status", self.status, None
        self.busy = BooleanMeta(
            "Whether Block busy or not", label="Busy").make_attribute()
        yield "busy", self.busy, None

    def create_meta(self):
        self.meta = BlockMeta()
        return "meta", self.meta, None

    def create_attributes(self):
        """Method that should provide Attribute instances for Block

        Yields:
            tuple: (string name, Attribute, callable put_function).
        """
        return iter(())

    def create_methods(self):
        """Method that should provide MethodMeta instances for Block

        Yields:
            tuple: (string name, MethodMeta, callable post_function).
        """
        return get_method_decorated(self)

    def transition(self, state, message, create_tasks=False):
        """Change to a new state if the transition is allowed

        Args:
            state (str): State to transition to
            message (str): Status message
            create_tasks (bool): If true then make self.part_tasks
        """
        with self.lock:
            if self.stateMachine.is_allowed(
                    initial_state=self.state.value, target_state=state):
                self._do_transition(state, message)
                if create_tasks:
                    self.part_tasks = self.create_part_tasks()
            else:
                raise TypeError("Cannot transition from %s to %s" %
                                (self.state.value, state))

    def _do_transition(self, state, message):
        # transition is allowed, so set attributes
        changes = []
        self.add_change(changes, self.state, "value", state)
        self.add_change(changes, self.status, "value", message)
        self.add_change(changes, self.busy, "value",
                        state in self.stateMachine.busy_states)

        # say which methods can now be called
        for name in self.block:
            child = self.block[name]
            if isinstance(child, MethodMeta):
                method = child
                writeable = self.methods_writeable[state][method]
                self.add_change(changes, method, "writeable", writeable)
                for ename in method.takes.elements:
                    meta = method.takes.elements[ename]
                    self.add_change(changes, meta, "writeable", writeable)

        self.log_debug("Transitioning to %s", state)
        self.block.apply_changes(*changes)

    def register_method_writeable(self, method, states):
        """Set the states that the given method can be called in

        Args:
            method (MethodMeta): Method that will be set writeable or not
            states (list[str]): List of states where method is writeable
        """
        for state in self.stateMachine.possible_states:
            writeable_dict = self.methods_writeable.setdefault(state, {})
            is_writeable = state in states
            writeable_dict[method] = is_writeable

    def create_part_tasks(self):
        part_tasks = {}
        for part_name, part in self.parts.items():
            part_tasks[part] = Task("Task(%s)" % part_name, self.process)
        return part_tasks

    def run_hook(self, hook, part_tasks, **kwargs):
        hook_queue, func_tasks, task_part_names = self.start_hook(
            hook, part_tasks, **kwargs)
        return_table = hook.make_return_table(part_tasks)
        return_dict = self.wait_hook(hook_queue, func_tasks, task_part_names)
        for part_name in self.parts:
            return_map = return_dict.get(part_name, None)
            if return_map:
                self.fill_in_table(part_name, return_table, return_map)
        return return_table

    def fill_in_table(self, part_name, table, return_map):
        # Find all the array columns
        arrays = {}
        for column_name in table.meta.elements:
            meta = table.meta.elements[column_name]
            if "hook:return_array" in meta.tags:
                arrays[column_name] = return_map[column_name]
        # If there are any arrays, make sure they are the right length
        lengths = set(len(arr) for arr in arrays.values())
        if len(lengths) == 0:
            # no arrays
            iterations = 1
        else:
            assert len(lengths) == 1, \
                "Varying array length %s for rows %s" % (lengths, arrays)
            iterations = lengths.pop()
        for i in range(iterations):
            row = []
            for k in table.endpoints:
                if k == "name":
                    row.append(part_name)
                elif k in arrays:
                    row.append(arrays[k][i])
                else:
                    row.append(return_map[k])
            table.append(row)

    def make_task_return_value_function(self, hook_queue, **kwargs):

        def task_return(func, task):
            try:
                result = func.MethodMeta.call_post_function(func, kwargs, task)
            except StopIteration as e:
                self.log_debug("%s has been aborted", func)
                result = e
            except Exception as e:  # pylint:disable=broad-except
                self.log_exception("%s %s raised exception", func, kwargs)
                result = e
            self.log_debug("Putting %r on queue", result)
            hook_queue.put((func, result))

        return task_return

    def start_hook(self, hook, part_tasks, **kwargs):
        assert hook in self.hook_names, \
            "Hook %s doesn't appear in controller hooks %s" % (
                hook, self.hook_names)
        self.log_debug("Running %s hook", self.hook_names[hook])

        # ask the hook to find the functions it should run
        func_tasks = hook.find_func_tasks(part_tasks)

        # now start them off
        hook_queue = self.process.create_queue()
        task_return = self.make_task_return_value_function(
            hook_queue, **kwargs)
        for func, task in func_tasks.items():
            task.define_spawn_function(task_return, func, task)
            self.log_debug("Starting task %r", task)
            task.start()

        # Create the reverse dictionary so we know where to store the results
        task_part_names = {}
        for part_name, part in self.parts.items():
            if part in part_tasks:
                task_part_names[part_tasks[part]] = part_name

        return hook_queue, func_tasks, task_part_names

    def wait_hook(self, hook_queue, func_tasks, task_part_names):
        # Wait for them all to finish
        return_dict = {}
        while func_tasks:
            func, ret = hook_queue.get()
            task = func_tasks.pop(func)
            part_name = task_part_names[task]
            return_dict[part_name] = ret

            if isinstance(ret, Exception):
                # Stop all other tasks
                for task in func_tasks.values():
                    task.stop()
                for task in func_tasks.values():
                    task.wait()

            # If we got a StopIteration, someone asked us to stop, so
            # don't wait, otherwise make sure we finished
            if not isinstance(ret, StopIteration):
                task.wait()

            if isinstance(ret, Exception):
                raise ret

        return return_dict
class Process(Loggable):
    """Hosts a number of Blocks, distributing requests between them"""

    def __init__(self, name, sync_factory):
        self.set_logger_name(name)
        self.name = name
        self.sync_factory = sync_factory
        self.q = self.create_queue()
        self._blocks = OrderedDict()  # block name -> block
        self._block_state_cache = Cache()
        self._recv_spawned = None
        self._other_spawned = []
        self._subscriptions = []
        self.comms = []
        self._client_comms = OrderedDict()  # client comms -> list of blocks
        self._handle_functions = {
            Post: self._forward_block_request,
            Put: self._forward_block_request,
            Get: self._handle_get,
            Subscribe: self._handle_subscribe,
            Unsubscribe: self._handle_unsubscribe,
            BlockChanges: self._handle_block_changes,
            BlockRespond: self._handle_block_respond,
            BlockAdd: self._handle_block_add,
            BlockList: self._handle_block_list,
        }
        self.create_process_block()

    def recv_loop(self):
        """Service self.q, distributing the requests to the right block"""
        while True:
            request = self.q.get()
            self.log_debug("Received request %s", request)
            if request is PROCESS_STOP:
                # Got the sentinel, stop immediately
                break
            try:
                self._handle_functions[type(request)](request)
            except Exception as e:  # pylint:disable=broad-except
                self.log_exception("Exception while handling %s", request)
                request.respond_with_error(str(e))

    def add_comms(self, comms):
        assert not self._recv_spawned, \
            "Can't add comms when process has been started"
        self.comms.append(comms)

    def start(self):
        """Start the process going"""
        self._recv_spawned = self.sync_factory.spawn(self.recv_loop)
        for comms in self.comms:
            comms.start()

    def stop(self, timeout=None):
        """Stop the process and wait for it to finish

        Args:
            timeout (float): Maximum amount of time to wait for each spawned
                process. None means forever
        """
        assert self._recv_spawned, "Process not started"
        self.q.put(PROCESS_STOP)
        for comms in self.comms:
            comms.stop()
        # Wait for recv_loop to complete first
        self._recv_spawned.wait(timeout=timeout)
        # Now wait for anything it spawned to complete
        for s in self._other_spawned:
            s.wait(timeout=timeout)

    def _forward_block_request(self, request):
        """Lookup target Block and spawn block.handle_request(request)

        Args:
            request (Request): The message that should be passed to the Block
        """
        block_name = request.endpoint[0]
        block = self._blocks[block_name]
        self._other_spawned.append(
            self.sync_factory.spawn(block.handle_request, request))

    def create_queue(self):
        """Create a queue using the sync_factory object

        Returns:
            Queue: New queue
        """
        return self.sync_factory.create_queue()

    def create_lock(self):
        """Create a lock using the sync_factory object

        Returns:
            Lock: New lock
        """
        return self.sync_factory.create_lock()

    def spawn(self, function, *args, **kwargs):
        """Calls SyncFactory.spawn()"""
        spawned = self.sync_factory.spawn(function, *args, **kwargs)
        self._other_spawned.append(spawned)
        return spawned

    def get_client_comms(self, block_name):
        for client_comms, blocks in list(self._client_comms.items()):
            if block_name in blocks:
                return client_comms

    def create_process_block(self):
        self.process_block = Block()
        # TODO: add a meta here
        children = OrderedDict()
        children["blocks"] = StringArrayMeta(
            description="Blocks hosted by this Process").make_attribute([])
        children["remoteBlocks"] = StringArrayMeta(
            description="Blocks reachable via ClientComms").make_attribute([])
        self.process_block.replace_endpoints(children)
        self.process_block.set_parent(self, self.name)
        self.add_block(self.process_block)

    def update_block_list(self, client_comms, blocks):
        self.q.put(BlockList(client_comms=client_comms, blocks=blocks))

    def _handle_block_list(self, request):
        self._client_comms[request.client_comms] = request.blocks
        remotes = []
        for blocks in self._client_comms.values():
            remotes += [b for b in blocks if b not in remotes]
        self.process_block["remoteBlocks"].set_value(remotes)

    def _handle_block_changes(self, request):
        """Update subscribers with changes and apply stored changes to the
        cached structure"""
        # update cached dict
        self._block_state_cache.apply_changes(*request.changes)

        for subscription in self._subscriptions:
            endpoint = subscription.endpoint
            # find stuff that's changed that is relevant to this subscriber
            changes = []
            for change in request.changes:
                change_path = change[0]
                # look for a change_path where the beginning matches the
                # endpoint path, then strip away the matching part and add
                # to the change set
                i = 0
                for (cp_element, ep_element) in zip(change_path, endpoint):
                    if cp_element != ep_element:
                        break
                    i += 1
                else:
                    # change has matching path, so keep it
                    # but strip off the end point path
                    filtered_change = [change_path[i:]] + change[1:]
                    changes.append(filtered_change)
            if len(changes) > 0:
                if subscription.delta:
                    # respond with the filtered changes
                    subscription.respond_with_delta(changes)
                else:
                    # respond with the structure of everything
                    # below the endpoint
                    d = self._block_state_cache.walk_path(endpoint)
                    subscription.respond_with_update(d)

    def report_changes(self, *changes):
        self.q.put(BlockChanges(changes=list(changes)))

    def block_respond(self, response, response_queue):
        self.q.put(BlockRespond(response, response_queue))

    def _handle_block_respond(self, request):
        """Push the response to the required queue"""
        request.response_queue.put(request.response)

    def add_block(self, block):
        """Add a block to be hosted by this process

        Args:
            block (Block): The block to be added
        """
        path = block.path_relative_to(self)
        assert len(path) == 1, \
            "Expected block %r to have %r as parent, got path %r" % \
            (block, self, path)
        name = path[0]
        assert name not in self._blocks, \
            "There is already a block called %r" % name
        request = BlockAdd(block=block, name=name)
        if self._recv_spawned:
            # Started, so call in Process thread
            self.q.put(request)
        else:
            # Not started yet so we are safe to add in this thread
            self._handle_block_add(request)

    def _handle_block_add(self, request):
        """Add a block to be hosted by this process"""
        block = request.block
        assert request.name not in self._blocks, \
            "There is already a block called %r" % request.name
        self._blocks[request.name] = block
        change_request = BlockChanges([[[request.name], block.to_dict()]])
        self._handle_block_changes(change_request)
        # Regenerate list of blocks
        self.process_block["blocks"].set_value(list(self._blocks))

    def get_block(self, block_name):
        try:
            return self._blocks[block_name]
        except KeyError:
            controller = ClientController(block_name, self)
            return controller.block

    def _handle_subscribe(self, request):
        """Add a new subscriber and respond with the current
        sub-structure state"""
        self._subscriptions.append(request)
        d = self._block_state_cache.walk_path(request.endpoint)
        self.log_debug("Initial subscription value %s", d)
        if request.delta:
            request.respond_with_delta([[[], d]])
        else:
            request.respond_with_update(d)

    def _handle_unsubscribe(self, request):
        """Remove a subscriber and respond with success or error"""
        subs = [s for s in self._subscriptions if s.id == request.id]
        # TODO: currently this will remove all subscriptions with a matching
        # id; there should only be one, we may want to warn if we see several.
        # Also, this should only filter by the queue/context, not sure
        # which yet...
        if len(subs) == 0:
            request.respond_with_error(
                "No subscription found for id %d" % request.id)
        else:
            self._subscriptions = \
                [s for s in self._subscriptions if s.id != request.id]
            request.respond_with_return()

    def _handle_get(self, request):
        d = self._block_state_cache.walk_path(request.endpoint)
        request.respond_with_return(d)
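For orientation, here is a minimal sketch of how the pieces above fit together, using only the calls that appear in this section (SyncFactory, Process, Block, set_parent, add_block, start, stop); the names "sched", "proc" and "myblock" are arbitrary placeholders, and the import path is an assumption rather than something confirmed by this section.

# Illustrative sketch only; the import path below is assumed, not confirmed.
from malcolm.core import Block, Process, SyncFactory

sync_factory = SyncFactory("sched")      # provides queues, locks and spawning
process = Process("proc", sync_factory)  # hosts Blocks and owns the recv loop

block = Block()
block.set_parent(process, "myblock")     # parent first, so add_block can
process.add_block(block)                 # derive the block's name from its path

process.start()                          # spawn recv_loop and start any comms
# ... requests put on process.q are now routed to the hosted blocks ...
process.stop(timeout=1.0)                # send the stop sentinel and wait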
    def test_notify(self):
        b = Block()
        b.set_parent(MagicMock(), "n")
        b.notify_subscribers()
        b.parent.notify_subscribers.assert_called_once_with("n")
    def test_add_block(self):
        p = Process("proc", MagicMock())
        b = Block()
        b.set_parent(p, "name")
        p.add_block(b)
        self.assertEqual(p._blocks["name"], b)