Ejemplo n.º 1
0
    def __init__(self, argv=None, run_now: bool = True):
        """Create a replayer for a Silk log and optionally start it at once.

        Args:
            argv (List[str], optional): command line arguments, program name
                first. Defaults to None.
            run_now (bool, optional): when True, replay the log immediately and,
                if a results directory was given, write a CSV summary into it.
                Pass False (e.g. from tests) to only construct the object.
        """
        options = SilkReplayer.parse_args(argv)

        # set_up_logger() and acquire_devices() below read these attributes.
        self.verbosity = options.verbosity
        self.logger = logging.getLogger("silk_replay")

        # Populated by acquire_devices().
        self.device_names = None
        self.device_name_map = None

        log_path = options.path
        self.input_path = log_path
        self.log_filename = os.path.basename(log_path)
        self.speed = float(options.playback_speed)

        output_dir = options.results_dir or os.getcwd()
        self.set_up_logger(output_dir)
        self.acquire_devices(options.hw_conf_file)

        manager_logger = self.logger.getChild("otnsManager")
        self.otns_manager = OtnsManager(server_host=options.otns_server,
                                        logger=manager_logger)

        # Timestamp of the previously replayed line; None until the first one.
        self.last_time = None

        if not run_now:
            return

        self.run()
        results_dir = options.results_dir
        if results_dir:
            summary_name = f"silk_replay_summary_for_{self.log_filename}.csv"
            self.output_summary(coalesced=True,
                                csv_path=os.path.join(results_dir,
                                                      summary_name))
Ejemplo n.º 2
0
 def setUp(self):
     """Test method set up.

     Starts OTNS (passing ``-ot-cli otns-silk-proxy`` and ``-listen :9000``)
     and connects an OtnsManager to it. If connecting fails, the OTNS
     instance is closed before the error propagates, because unittest does
     not call tearDown() when setUp() raises.
     """
     self.ns = OTNS(otns_args=[
         "-raw", "-real", "-ot-cli", "otns-silk-proxy", "-listen", ":9000",
         "-log", "debug"
     ])
     try:
         # wait for OTNS gRPC server to start
         time.sleep(0.3)
         self.manager = OtnsManager("localhost",
                                    self.logger.getChild("OtnsManager"))
         self.manager.wait_for_grpc_channel_ready(10)
     except BaseException:
         # tearDown() will not run; don't leak the OTNS instance into the
         # following tests.
         self.ns.close()
         raise
Ejemplo n.º 3
0
    def setUp(self):
        """Build the mocked fixtures used by every test.

        Creates one exception queue shared by the mocks, an OtnsManager whose
        gRPC client is swapped for a MockGrpcClient, and a MockUDPServer.
        """
        errors = queue.Queue()
        self.exception_queue = errors

        self.manager = OtnsManager("localhost",
                                   self.logger.getChild("OtnsManager"))

        # Replace the manager's gRPC client with the mock; both the mock
        # client and the UDP server report into the same queue.
        mock_grpc = MockGrpcClient(errors,
                                   self.logger.getChild("MockGrpcClient"))
        self.grpc_client = mock_grpc
        self.manager.grpc_client = mock_grpc

        self.udp_server = MockUDPServer(errors)
Ejemplo n.º 4
0
    def wrapper(*args, **kwargs):
        """Prepare class-level test state, then run the wrapped setUpClass.

        ``args[0]`` is the test class. This creates a timestamped output
        directory and framework logger, registers the class in ``cls.results``,
        optionally connects to OTNS, and then calls ``func`` (the wrapped
        setUpClass, taken from the enclosing scope). If ``func`` or the
        follow-up sniffer/OTNS steps raise, ``setupClass`` is recorded as False
        in ``results.json``, devices are released and the exception is
        re-raised.
        """
        start_time_string = time.strftime(DATE_FORMAT)
        cls = args[0]

        cls.thread_sniffers.clear()

        if OUTPUT_DIRECTORY_KEY in os.environ:
            cls.top_output_directory = os.environ[OUTPUT_DIRECTORY_KEY]
        else:
            cls.top_output_directory = silk.config.defaults.DEFAULT_LOG_OUTPUT_DIRECTORY

        # Set the test class name
        cls.current_test_class = cls.__name__

        # Create an output directory. exist_ok avoids a spurious failure when
        # the same class is set up again while time.strftime(DATE_FORMAT)
        # still yields the same prefix.
        cls.current_output_directory = os.path.join(
            cls.top_output_directory,
            start_time_string + "_" + cls.current_test_class)
        os.makedirs(cls.current_output_directory, exist_ok=True)

        # Create a new logger for the test framework
        silk_log_dest = os.path.join(cls.current_output_directory, "silk.log")
        cls.logger = get_framework_logger(silk_log_dest)
        cls.logger.info("Log dest: %s", cls.current_output_directory)
        cls.logger.info("SET UP CLASS %s", cls.current_test_class)

        # Establish a results dictionary; (re)create the top-level one when the
        # class has none yet (AttributeError) or it cannot be assigned into,
        # e.g. it is None (TypeError).
        try:
            cls.results[cls.current_test_class] = collections.OrderedDict()
        except (AttributeError, TypeError):
            cls.results = collections.OrderedDict()
            cls.results[cls.current_test_class] = collections.OrderedDict()

        curr_suite_id = cls._testrail_dict_lookup(SUITE_ID)

        cls.results[cls.current_test_class][SUITE_ID] = curr_suite_id

        if _OTNS_HOST:
            cls.otns_manager = OtnsManager(
                server_host=_OTNS_HOST,
                logger=cls.logger.getChild("otnsManager"))
            cls.otns_manager.set_test_title(f"{cls.current_test_class}.set_up")
            cls.otns_manager.set_replay_speed(1.0)
        else:
            cls.otns_manager = None

        def _write_results():
            """Dump cls.results to results.json in the class output directory."""
            results_path = os.path.join(cls.current_output_directory,
                                        "results.json")
            # `with` closes the file even if json.dump fails part-way.
            with open(results_path, "w") as output_file:
                json.dump(cls.results, output_file, indent=4)

        # Call the user's setUpClass
        try:
            func(*args, **kwargs)

            # If the user's setUpClass call succeeded, try to Thread sniffers
            cls.thread_sniffer_start_all()

            if _OTNS_HOST:
                for device in cls.device_list:
                    cls.get_device_extaddr(device)
        except HardwareNotFound:
            cls.results[cls.current_test_class]["setupClass"] = False

            cls.release_devices()
            cls.logger.error("Hardware Not Found Error !!!")
            _write_results()

            raise

        except BaseException:
            # Deliberately as broad as a bare except: the failure is logged,
            # recorded and re-raised, and devices are released even on
            # KeyboardInterrupt.
            exc_type, exc_value, exc_tb = sys.exc_info()
            for call in traceback.format_tb(exc_tb):
                for line in call.rstrip().splitlines():
                    cls.logger.error(line)
            # format_tb() yields only stack frames; log the exception as well.
            for entry in traceback.format_exception_only(exc_type, exc_value):
                for line in entry.rstrip().splitlines():
                    cls.logger.error(line)

            cls.results[cls.current_test_class]["setupClass"] = False
            cls.logger.error("Hardware Configuration Error !!!")
            _write_results()

            cls.logger.info("=" * 70)
            cls.logger.info("=" * 20 + " CHECK HARDWARE CONFIGURATION " +
                            "=" * 19)
            cls.logger.info("=" * 70)
            cls.logger.info(("%s" % cls.__name__).ljust(34, ".") +
                            "FAILED SETUPCLASS".rjust(34, "."))
            cls.logger.info("=" * 70)
            cls.release_devices()
            raise
Ejemplo n.º 5
0
class SilkReplayer(object):
    """Replay topology changes and transmissions from a Silk log.

    Attributes:
        speed (float): speed ratio for the replay. 2.0 means speeding up to 2x.
        verbosity (int): terminal log verbosity.
        input_path (str): input Silk log file path.
        log_filename (str): name of input log file.
        logger (logging.Logger): logger for the replayer.

        device_names (Set[str]): name of hardware modules from hwconfig.ini file.
        device_name_map (Dict[str, ThreadDevBoard]): map from device name to
            ThreadDevBoard instance.
        otns_manager (OtnsManager): manager for OTNS communications.

        last_time (datetime.datetime): timestamp of the last line of log processed.
    """
    def __init__(self, argv=None, run_now: bool = True):
        """Initialize a Silk log replayer.

        Args:
            argv (List[str], optional): command line arguments, program name
                first. Defaults to None, meaning ``sys.argv``.
            run_now (bool, optional): if the replayer should start running immediately. Useful to set to False to
                run tests on this class.
        """
        args = SilkReplayer.parse_args(argv)
        self.verbosity = args.verbosity
        self.logger = logging.getLogger("silk_replay")
        self.device_names = None
        self.device_name_map = None

        self.input_path = args.path
        self.log_filename = os.path.basename(args.path)
        self.speed = float(args.playback_speed)

        self.set_up_logger(args.results_dir or os.getcwd())
        self.acquire_devices(args.hw_conf_file)

        self.otns_manager = OtnsManager(
            server_host=args.otns_server,
            logger=self.logger.getChild("otnsManager"))

        self.last_time = None

        if run_now:
            self.run()
            if args.results_dir:
                result_path = os.path.join(
                    args.results_dir,
                    f"silk_replay_summary_for_{self.log_filename}.csv")
                self.output_summary(coalesced=True, csv_path=result_path)

    def set_up_logger(self, result_dir: str):
        """Set up logger for the replayer.

        Terminal output is filtered by ``self.verbosity`` (0: CRITICAL only,
        1: INFO and above, otherwise everything). The log file written to
        ``result_dir`` always receives DEBUG and above.

        Args:
            result_dir (str): output directory of log.
        """
        self.logger.setLevel(logging.DEBUG)

        if self.verbosity == 0:
            stream_level = logging.CRITICAL
        elif self.verbosity == 1:
            stream_level = logging.INFO
        else:
            stream_level = logging.DEBUG

        # Root logger configuration, for loggers outside this hierarchy.
        logging.basicConfig(format=LOG_LINE_FORMAT, level=stream_level)

        formatter = logging.Formatter(LOG_LINE_FORMAT)

        result_path = os.path.join(
            result_dir, f"silk_replay_log_for_{self.log_filename}.log")

        file_handler = logging.FileHandler(result_path, mode="w")
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)

        # Records propagated to ancestor loggers bypass the ancestors' levels,
        # so the root level set above would not stop DEBUG records of this
        # (DEBUG-level) logger from reaching the terminal. Use a dedicated,
        # level-filtered console handler and stop propagation instead.
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(stream_level)
        stream_handler.setFormatter(formatter)

        # Close handlers from a previous call so their files are not leaked.
        for handler in list(self.logger.handlers):
            self.logger.removeHandler(handler)
            handler.close()
        self.logger.addHandler(file_handler)
        self.logger.addHandler(stream_handler)
        self.logger.propagate = False

    @staticmethod
    def parse_args(argv):
        """Parse arguments.

        Args:
            argv (List[str] or None): command line arguments including the
                program name (skipped). None parses ``sys.argv``.

        Returns:
            argparse.Namespace: parsed arguments attributes.

        Raises:
            SystemExit: on invalid arguments, including a non-positive
                playback speed (the replay divides by it).
        """
        parser = argparse.ArgumentParser(description="Replay a Silk test log")
        parser.add_argument("-r",
                            "--results_dir",
                            dest="results_dir",
                            metavar="ResPath",
                            help="Set the path for run results")
        parser.add_argument("-c",
                            "--hwconfig",
                            dest="hw_conf_file",
                            metavar="ConfFile",
                            default="/opt/openthread_test/hwconfig.ini",
                            help="Name the hardware config file")
        parser.add_argument(
            "-v",
            "--verbose",
            "--verbosity",
            type=int,
            default=1,
            choices=list(range(0, 3)),
            dest="verbosity",
            metavar="X",
            help="Verbosity level (0=quiet, 1=default, 2=verbose)")
        parser.add_argument("-s",
                            "--otns",
                            dest="otns_server",
                            metavar="OtnsServer",
                            default="localhost",
                            help="OTNS server address")
        parser.add_argument("-p",
                            "--speed",
                            dest="playback_speed",
                            type=float,
                            default=1.0,
                            metavar="PlaybackSpeed",
                            help="Speed of log replay")
        parser.add_argument("path", metavar="P", help="Log file path")
        # argparse falls back to sys.argv[1:] when given None.
        args = parser.parse_args(None if argv is None else argv[1:])
        if args.playback_speed <= 0:
            parser.error("playback speed must be positive")
        return args

    def acquire_devices(self, config_file: str):
        """Acquire devices from hwconfig.ini file.

        Args:
            config_file (str): path to hwconfig.ini file.
        """
        hw_resource.global_instance(config_file, virtual=True)
        hw_resource.global_instance().load_config()
        self.device_names = set(
            hw_resource.global_instance().get_hw_module_names())
        self.device_name_map = dict()
        self.logger.debug("Loaded devices %s", self.device_names)

    def execute_message(self, entity_name: str, message: str,
                        timestamp: datetime):
        """Execute the intended action represented by the message.

        Messages from the bare ``silk`` entity update the OTNS test title.
        Messages from ``silk.<device>`` (a device listed in hwconfig) drive
        node add/remove, extaddr, NCP version and status updates. Everything
        else, including ``silk.otnsManager``, is ignored.

        Args:
            entity_name (str): name of the entity carrying out the action.
            message (str): message content of the action.
            timestamp (datetime.datetime): timestamp of the message.
        """
        parts = entity_name.split(".")
        if len(parts) == 1 and parts[0] == "silk":
            set_up_class_match = re.match(RegexType.SET_UP_CLASS.value,
                                          message)
            if set_up_class_match:
                self.otns_manager.set_test_title(
                    f"{set_up_class_match.group(1)}.set_up")
                return

            teardown_class_done_match = re.match(
                RegexType.TEARDOWN_CLASS_DONE.value, message)
            if teardown_class_done_match:
                self.otns_manager.set_test_title("")
                return

            teardown_class_match = re.match(RegexType.TEARDOWN_CLASS.value,
                                            message)
            if teardown_class_match:
                self.otns_manager.set_test_title(
                    f"{teardown_class_match.group(1)}.tear_down")
                return

            running_test_match = re.match(RegexType.RUNNING_TEST.value,
                                          message)
            if running_test_match:
                self.otns_manager.set_test_title(running_test_match.group(1))
                return
        if len(parts) < 2 or parts[0] != "silk" or parts[1] == "otnsManager":
            return

        device_name = parts[1]
        if device_name not in self.device_names:
            return

        # Virtual boards are created lazily, once per device name.
        if device_name not in self.device_name_map:
            device = ffdb.ThreadDevBoard(virtual=True,
                                         virtual_name=device_name)
            self.device_name_map[device_name] = device
        else:
            device = self.device_name_map[device_name]

        start_match = re.match(OtnsRegexType.START_WPANTUND_RES.value, message)
        if start_match:
            self.otns_manager.add_node(device)
            return

        stop_match = re.match(OtnsRegexType.STOP_WPANTUND_REQ.value, message)
        if stop_match:
            self.otns_manager.remove_node(device)
            return

        extaddr_match = re.search(OtnsRegexType.GET_EXTADDR_RES.value, message)
        if extaddr_match:
            self.otns_manager.update_extaddr(device,
                                             int(extaddr_match.group(1), 16),
                                             time=timestamp)
            return

        ncp_version_match = re.search(OtnsRegexType.NCP_VERSION.value, message)
        if ncp_version_match:
            ncp_version = ncp_version_match.group(1)
            self.otns_manager.set_ncp_version(ncp_version)
            return

        status_match = re.match(OtnsRegexType.STATUS.value, message)
        if status_match:
            self.otns_manager.process_node_status(device,
                                                  message,
                                                  time=timestamp)
            return

    def output_summary(self, coalesced: bool, csv_path: str):
        """Print summary of the replayed log.

        When ``csv_path`` is given the summary is written there as CSV and
        ``coalesced`` is ignored; otherwise it is logged at DEBUG level.

        Args:
            coalesced (bool): if the summary should be printed grouped by time.
            csv_path (str): path to CSV output file
        """
        # Map each node's latest extended address to its node ID.
        extaddr_map = {}
        for summary in self.otns_manager.node_summaries.values():
            if summary.extaddr_history:
                extaddr_map[summary.extaddr_history[-1][1]] = summary.node_id
        if csv_path:
            collection = OtnsNodeSummaryCollection(
                self.otns_manager.node_summaries.values())
            data_frame = collection.to_csv(extaddr_map)
            data_frame.to_csv(csv_path, index=False)
        elif coalesced:
            collection = OtnsNodeSummaryCollection(
                self.otns_manager.node_summaries.values())
            self.logger.debug(collection.to_string(extaddr_map))
        else:
            for summary in self.otns_manager.node_summaries.values():
                self.logger.debug(summary.to_string(extaddr_map))

    def run(self, start_line: int = 0, stop_regex: str = None) -> int:
        """Run the Silk log replayer.

        Before executing each recognised log line, the replayer sleeps for the
        time gap since the previous recognised line divided by ``self.speed``.
        This method provides two optional arguments to allow for unit testing.

        Args:
            start_line (int, optional): start reading the log file at the specified line number. Defaults to 0.
            stop_regex (str, optional): stop running if the pattern matches a log line. Defaults to None.

        Returns:
            int: number of the line where replay ended: the first line matching
            ``stop_regex`` (which is not executed), otherwise the last line of
            the file; -1 if the file is empty.
        """
        self.otns_manager.set_replay_speed(self.speed)
        # Returned unchanged for an empty file, where the loop never binds it.
        line_number = -1
        with open(file=self.input_path, mode="r") as file:
            for line_number, line in enumerate(file):
                if line_number < start_line:
                    continue
                if stop_regex and re.search(stop_regex, line):
                    return line_number
                line_match = re.search(RegexType.LOG_LINE.value, line)
                if not line_match:
                    continue
                timestamp = datetime.strptime(line_match.group(1),
                                              DATE_FORMAT)
                if self.last_time is None:
                    self.last_time = timestamp
                time_diff = timestamp - self.last_time
                delay = time_diff.total_seconds() / self.speed
                self.last_time = timestamp

                entity_name = line_match.group(2)
                message = line_match.group(4)

                # delay for the time difference between two log lines
                if delay > 0:
                    time.sleep(delay)
                self.execute_message(entity_name, message, timestamp)

        return line_number
Ejemplo n.º 6
0
class OTNSIntegrationTest(SilkTestCase):
    """Silk integration test case for OTNS.
    """
    def setUp(self):
        """Test method set up.

        Starts OTNS (passing ``-ot-cli otns-silk-proxy`` and ``-listen :9000``)
        and connects an OtnsManager to it. If connecting fails, the OTNS
        instance is closed before the error propagates, because unittest does
        not call tearDown() when setUp() raises.
        """
        self.ns = OTNS(otns_args=[
            "-raw", "-real", "-ot-cli", "otns-silk-proxy", "-listen", ":9000",
            "-log", "debug"
        ])
        try:
            # wait for OTNS gRPC server to start
            time.sleep(0.3)
            self.manager = OtnsManager("localhost",
                                       self.logger.getChild("OtnsManager"))
            self.manager.wait_for_grpc_channel_ready(10)
        except BaseException:
            # tearDown() will not run; don't leak OTNS into later tests.
            self.ns.close()
            raise

    def tearDown(self) -> None:
        """Test method tear down.
        """
        self.manager.unsubscribe_from_all_nodes()
        self.manager.remove_all_nodes()
        self.ns.close()
        # wait for OTNS gRPC server to stop
        time.sleep(0.2)

    def assert_device_positions(self, nodes_info: Dict[int, Dict[str, Any]],
                                expected_coords: Dict[int, Tuple[int, int]]):
        """Assert node coordinates match the expected ones within 1 unit.

        Args:
            nodes_info (Dict[int, Dict[str, Any]]): nodes info dictionary.
            expected_coords (Dict[int, Tuple[int, int]]): dict mapping device id to coordinates to check.
        """
        for device_id, coords in expected_coords.items():
            self.assertAlmostEqual(nodes_info[device_id]["x"],
                                   coords[0],
                                   delta=1)
            self.assertAlmostEqual(nodes_info[device_id]["y"],
                                   coords[1],
                                   delta=1)

    def testAddDevice(self):
        """Test adding device.
        """
        ns = self.ns
        manager = self.manager

        device = MockThreadDevBoard(1)
        manager.add_node(device)
        ns.go(0.1)
        self.assertEqual(len(ns.nodes()), 1)

    def testRemoveDevice(self):
        """Test removing device.
        """
        ns = self.ns
        manager = self.manager

        device = MockThreadDevBoard(1)
        manager.add_node(device)
        ns.go(0.1)
        self.assertEqual(len(ns.nodes()), 1)

        manager.remove_node(device)
        ns.go(0.1)
        self.assertEqual(len(ns.nodes()), 0)

    def testSetSpeed(self):
        """Test setting speed display.
        """
        ns = self.ns
        manager = self.manager

        speed = random.randint(2, 20)
        manager.set_replay_speed(speed)
        self.assertAlmostEqual(ns.speed, speed)

        speed = random.randint(21, 40)
        manager.set_replay_speed(speed)
        self.assertAlmostEqual(ns.speed, speed)

    def testAddFixedPositionDevices(self):
        """Test adding fixed position nodes.
        """
        def assert_device_fixed_positions(nodes_info: Dict[int, Dict[str,
                                                                     Any]],
                                          devices: List[MockThreadDevBoard]):
            """Helper method to assert fixed position devices coordinates.

            Args:
                nodes_info (Dict[int, Dict[str, Any]]): nodes info dictionary.
                devices (List[MockThreadDevBoard]): list of devices to check.
            """
            for a_device in devices:
                self.assertEqual(nodes_info[a_device.id]["x"], a_device.x)
                self.assertEqual(nodes_info[a_device.id]["y"], a_device.y)

        ns = self.ns
        manager = self.manager

        device_1 = MockThreadDevBoard(random.randint(1, 10))
        device_2 = MockThreadDevBoard(random.randint(11, 20))
        device_3 = MockThreadDevBoard(random.randint(21, 30))

        for device in [device_1, device_2, device_3]:
            device.device.set_otns_vis_position(random.randint(100, 200),
                                                random.randint(100, 200))

        manager.add_node(device_1)
        manager.add_node(device_2)
        ns.go(0.1)

        nodes_info = ns.nodes()
        self.assertEqual(len(nodes_info), 2)
        assert_device_fixed_positions(nodes_info, [device_1, device_2])

        manager.add_node(device_3)
        ns.go(0.1)

        nodes_info = ns.nodes()
        self.assertEqual(len(nodes_info), 3)
        assert_device_fixed_positions(nodes_info,
                                      [device_1, device_2, device_3])

    def testAddAutoLayoutDevices(self):
        """Test adding auto layout nodes.
        """
        ns = self.ns
        manager = self.manager

        layout_center_x = random.randint(100, 200)
        layout_center_y = random.randint(100, 200)
        layout_radius = random.randint(50, 100)

        device_1 = MockThreadDevBoard(1)
        device_2 = MockThreadDevBoard(2)
        device_3 = MockThreadDevBoard(3)
        device_4 = MockThreadDevBoard(4)

        for device in [device_1, device_2, device_3, device_4]:
            device.device.set_otns_layout_parameter(layout_center_x,
                                                    layout_center_y,
                                                    layout_radius)

        manager.add_node(device_1)
        ns.go(0.1)

        nodes_info = ns.nodes()
        self.assertEqual(len(nodes_info), 1)
        # placing the first node alone
        expected_coords = {
            device_1.id: (layout_center_x + layout_radius, layout_center_y)
        }
        self.assert_device_positions(nodes_info, expected_coords)

        manager.add_node(device_2)
        ns.go(0.1)

        nodes_info = ns.nodes()
        self.assertEqual(len(nodes_info), 2)
        # forming a horizontal line
        expected_coords = {
            device_1.id: (layout_center_x - layout_radius, layout_center_y),
            device_2.id: (layout_center_x + layout_radius, layout_center_y)
        }
        self.assert_device_positions(nodes_info, expected_coords)

        manager.add_node(device_3)
        manager.add_node(device_4)
        ns.go(0.1)

        nodes_info = ns.nodes()
        self.assertEqual(len(nodes_info), 4)
        # forming a cross shape
        expected_coords = {
            device_1.id: (layout_center_x, layout_center_y + layout_radius),
            device_2.id: (layout_center_x - layout_radius, layout_center_y),
            device_3.id: (layout_center_x, layout_center_y - layout_radius),
            device_4.id: (layout_center_x + layout_radius, layout_center_y)
        }
        self.assert_device_positions(nodes_info, expected_coords)

    def testRemoveAutoLayoutDevices(self):
        """Test that removing nodes keeps other nodes stationary with auto layout.
        """
        ns = self.ns
        manager = self.manager

        layout_center_x = random.randint(100, 200)
        layout_center_y = random.randint(100, 200)
        layout_radius = random.randint(50, 100)

        device_1 = MockThreadDevBoard(1)
        device_2 = MockThreadDevBoard(2)
        device_3 = MockThreadDevBoard(3)
        device_4 = MockThreadDevBoard(4)

        for device in [device_1, device_2, device_3, device_4]:
            device.device.set_otns_layout_parameter(layout_center_x,
                                                    layout_center_y,
                                                    layout_radius)
            manager.add_node(device)

        ns.go(0.1)

        nodes_info = ns.nodes()
        self.assertEqual(len(nodes_info), 4)
        expected_coords = {
            device_1.id: (layout_center_x, layout_center_y + layout_radius),
            device_2.id: (layout_center_x - layout_radius, layout_center_y),
            device_3.id: (layout_center_x, layout_center_y - layout_radius),
            device_4.id: (layout_center_x + layout_radius, layout_center_y)
        }
        self.assert_device_positions(nodes_info, expected_coords)

        manager.remove_node(device_4)
        ns.go(0.1)
        nodes_info = ns.nodes()
        self.assertEqual(len(nodes_info), 3)
        expected_coords = {
            device_1.id: (layout_center_x, layout_center_y + layout_radius),
            device_2.id: (layout_center_x - layout_radius, layout_center_y),
            device_3.id: (layout_center_x, layout_center_y - layout_radius)
        }
        self.assert_device_positions(nodes_info, expected_coords)

        manager.remove_node(device_3)
        ns.go(0.1)
        nodes_info = ns.nodes()
        self.assertEqual(len(nodes_info), 2)
        expected_coords = {
            device_1.id: (layout_center_x, layout_center_y + layout_radius),
            device_2.id: (layout_center_x - layout_radius, layout_center_y)
        }
        self.assert_device_positions(nodes_info, expected_coords)

        manager.remove_node(device_2)
        ns.go(0.1)
        nodes_info = ns.nodes()
        self.assertEqual(len(nodes_info), 1)
        expected_coords = {
            device_1.id: (layout_center_x, layout_center_y + layout_radius)
        }
        self.assert_device_positions(nodes_info, expected_coords)

        manager.add_node(device_2)
        manager.remove_node(device_1)
        ns.go(0.1)
        nodes_info = ns.nodes()
        self.assertEqual(len(nodes_info), 1)
        expected_coords = {
            device_2.id: (layout_center_x - layout_radius, layout_center_y)
        }
        self.assert_device_positions(nodes_info, expected_coords)

        manager.add_node(device_3)
        manager.remove_node(device_2)
        ns.go(0.1)
        nodes_info = ns.nodes()
        self.assertEqual(len(nodes_info), 1)
        expected_coords = {
            device_3.id: (layout_center_x, layout_center_y - layout_radius)
        }
        self.assert_device_positions(nodes_info, expected_coords)

        manager.add_node(device_4)
        manager.remove_node(device_3)
        ns.go(0.1)
        nodes_info = ns.nodes()
        self.assertEqual(len(nodes_info), 1)
        expected_coords = {
            device_4.id: (layout_center_x + layout_radius, layout_center_y)
        }
        self.assert_device_positions(nodes_info, expected_coords)

    def testUpdateExtaddr(self):
        """Test updating node extended address.

        Also tests updating before the OTNS manager subscribes to the node.
        """
        ns = self.ns
        manager = self.manager

        device_extaddr = random.getrandbits(64)
        device = MockThreadDevBoard(random.randint(1, 10))

        manager.add_node(device)
        ns.go(0.1)

        self.assertEqual(ns.nodes()[device.id]["extaddr"], device.id)

        device.wpantund_process.emit_status(f"extaddr={device_extaddr:016x}")
        ns.go(0.1)

        self.assertEqual(ns.nodes()[device.id]["extaddr"], device.id)

        manager.subscribe_to_node(device)
        device.wpantund_process.emit_status(f"extaddr={device_extaddr:016x}")
        ns.go(0.1)

        self.assertEqual(ns.nodes()[device.id]["extaddr"], device_extaddr)

    def testUpdateRLOC16(self):
        """Test updating node RLOC16.

        Also tests updating before the OTNS manager subscribes to the node.
        """
        ns = self.ns
        manager = self.manager

        device_rloc16 = random.getrandbits(16)
        device = MockThreadDevBoard(random.randint(1, 10))

        manager.add_node(device)
        ns.go(0.1)

        original_rloc16 = ns.nodes()[device.id]["rloc16"]

        device.wpantund_process.emit_status(f"rloc16={device_rloc16}")
        ns.go(0.1)

        self.assertEqual(ns.nodes()[device.id]["rloc16"], original_rloc16)

        manager.subscribe_to_node(device)
        device.wpantund_process.emit_status(f"rloc16={device_rloc16}")
        ns.go(0.1)

        self.assertEqual(ns.nodes()[device.id]["rloc16"], device_rloc16)

    def testFormPartition(self):
        """Test forming a partition.
        """
        ns = self.ns
        manager = self.manager

        # Draw two *distinct* partition IDs; two independent draws could
        # collide and make the "two partitions" assertions fail spuriously.
        device_1_parid, device_2_parid = random.sample(range(1 << 16), 2)
        device_1 = MockThreadDevBoard(random.randint(1, 10))
        device_2 = MockThreadDevBoard(random.randint(11, 20))

        manager.add_node(device_1)
        manager.add_node(device_2)

        manager.subscribe_to_node(device_1)
        manager.subscribe_to_node(device_2)

        device_1.wpantund_process.emit_status(f"parid={device_1_parid:08x}")
        device_2.wpantund_process.emit_status(f"parid={device_2_parid:08x}")
        ns.go(0.1)

        partitions_info = ns.partitions()
        self.assertEqual(len(partitions_info), 2)
        self.assertEqual(len(partitions_info[device_1_parid]), 1)
        self.assertEqual(len(partitions_info[device_2_parid]), 1)
        self.assertEqual(partitions_info[device_1_parid][0], device_1.id)
        self.assertEqual(partitions_info[device_2_parid][0], device_2.id)

        device_2.wpantund_process.emit_status(f"parid={device_1_parid:08x}")
        ns.go(0.1)

        partitions_info = ns.partitions()
        self.assertEqual(len(partitions_info), 1)
        self.assertEqual(len(partitions_info[device_1_parid]), 2)
        self.assertIn(device_1.id, partitions_info[device_1_parid])
        self.assertIn(device_2.id, partitions_info[device_1_parid])

        device_2.wpantund_process.emit_status(f"parid={device_2_parid:08x}")
        ns.go(0.1)

        partitions_info = ns.partitions()
        self.assertEqual(len(partitions_info), 2)
        self.assertEqual(len(partitions_info[device_1_parid]), 1)
        self.assertEqual(len(partitions_info[device_2_parid]), 1)
        self.assertEqual(partitions_info[device_1_parid][0], device_1.id)
        self.assertEqual(partitions_info[device_2_parid][0], device_2.id)
Ejemplo n.º 7
0
class SilkMockingTestCase(SilkTestCase):
    """Silk test case with basic mocked OTNS and manager set up.
    """
    def setUp(self):
        """Test method set up.
        """
        self.exception_queue = queue.Queue()

        self.manager = OtnsManager("localhost",
                                   self.logger.getChild("OtnsManager"))
        self.grpc_client = MockGrpcClient(
            self.exception_queue, self.logger.getChild("MockGrpcClient"))
        self.manager.grpc_client = self.grpc_client

        self.udp_server = MockUDPServer(self.exception_queue)

    def tearDown(self):
        """Test method tear down. Clean up fixtures.

        The UDP server is closed even if unsubscribing from or removing nodes
        raises, so it is always released before the next test.
        """
        try:
            self.manager.unsubscribe_from_all_nodes()
            self.manager.remove_all_nodes()
        finally:
            self.udp_server.close()

    def _fail_on_reported_exception(self):
        """Fail the test if an exception is waiting in the exception queue."""
        try:
            exception = self.exception_queue.get(block=False)
        except queue.Empty:
            return
        self.fail(exception)

    def wait_for_expect(self, expect_thread: threading.Thread):
        """Wait for expectation to be fulfilled.

        Fails the test as soon as an exception shows up in the exception
        queue, both while the thread runs and after it has finished.

        Args:
            expect_thread (threading.Thread): thread running expectation.
        """
        while expect_thread.is_alive():
            self._fail_on_reported_exception()
            expect_thread.join(0.1)
        # An exception may be queued right before the thread exits, after the
        # last check inside the loop; check once more so it is not lost.
        self._fail_on_reported_exception()

    def expect_grpc_commands(self, commands: List[str]) -> threading.Thread:
        """Create a thread for an expecting gRPC commands.

        Args:
            commands (List[str]): expecting gRPC commands.

        Returns:
            threading.Thread: thread running the expectation.
        """
        expect_thread = threading.Thread(
            target=self.grpc_client.expect_commands, args=(commands, ))
        expect_thread.start()
        return expect_thread

    def expect_udp_messages(
            self, messages: List[Tuple[str, int]]) -> threading.Thread:
        """Create a thread for an expecting UDP message.

        Args:
            messages (List[Tuple[str, int]]): list of expected UDP messages and corresponding source ID.

        Returns:
            threading.Thread: thread running the expectation.
        """
        # convert source IDs to source ports
        messages = [(message, 9000 + source_id)
                    for message, source_id in messages]
        expect_thread = threading.Thread(
            target=self.udp_server.expect_messages, args=(messages, ))
        expect_thread.start()
        return expect_thread