# Example #1
0
class Node:
    """One node of a decision tree that classifies 5-field packets.

    ``ranges`` is a flat sequence of ten numbers: a half-open
    ``[low, high)`` interval for each of the five packet fields, in the
    order source IP, destination IP, source port, destination port,
    protocol (``ranges[2 * d]`` is the low bound of dimension ``d`` and
    ``ranges[2 * d + 1]`` its exclusive high bound).

    Attributes:
        id: Identifier of the node (formatted with ``%d`` by ``__str__``).
        partitions: Private list copy of the ``partitions`` argument; each
            entry is unpacked as ``(smaller, part_dim, part_size)`` by
            ``get_state``.
        manual_partition: ``None``, or an index into the 70-element
            partition section of the state vector.
        ranges: The box covered by this node (layout described above).
        rules: The rules held by this node.  A ``list`` argument is wrapped
            in ``RuleSet``; anything else is stored as given.
        depth: Depth of the node in the tree.
        children: Child nodes; empty for a leaf.
        action: ``None`` until assigned from outside the class; when set,
            ``action[0]`` is a tag such as ``"partition"`` or ``"cut"``.
        pushup_rules: ``None`` or an iterable of rules, only used by
            ``__str__`` within this class.
        num_rules: ``len(rules)`` captured at construction time; it is not
            refreshed if ``rules`` changes afterwards.
    """

    def __init__(self, id, ranges, rules, depth, partitions, manual_partition):
        self.id = id
        # Copy so that later mutation of the caller's list does not leak in;
        # ``None`` (or any falsy value) becomes an empty list.
        self.partitions = list(partitions or [])
        self.manual_partition = manual_partition
        self.ranges = ranges
        self.rules = RuleSet(rules) if isinstance(rules, list) else rules
        self.depth = depth
        self.children = []
        self.action = None
        self.pushup_rules = None
        self.num_rules = len(self.rules)

    def is_partition(self):
        """Return True if this node's action is a ``"partition"`` action.

        A node without an action, or with any other action tag (for
        example ``"cut"``), is not a partition node.
        """
        if not self.action:
            return False
        return self.action[0] == "partition"

    def match(self, packet):
        """Return the rule of this subtree that matches ``packet``, or None.

        * Partition node: every child is queried and, among the rules they
          return, the one that appears earliest in ``self.rules`` wins.
        * Other internal node: the lookup descends into the first child
          whose box contains the packet; ``None`` if no child contains it.
        * Leaf: the first rule in ``self.rules`` whose ``matches(packet)``
          is truthy is returned.
        """
        if self.is_partition():
            candidates = []
            for child in self.children:
                found = child.match(packet)
                if found:
                    candidates.append(found)
            if not candidates:
                return None
            # min() returns the first minimal element, exactly what a
            # stable sort followed by [0] would give, without sorting.
            return min(candidates, key=lambda rule: self.rules.index(rule))

        if self.children:
            for child in self.children:
                if child.contains(packet):
                    return child.match(packet)
            return None

        for rule in self.rules:
            if rule.matches(packet):
                return rule
        return None

    def is_intersect_multi_dimension(self, ranges):
        """Return True if the box ``ranges`` overlaps this node's box.

        ``ranges`` uses the same ten-element half-open layout as
        ``self.ranges``.  The boxes overlap only if their intervals overlap
        in every one of the five dimensions.
        """
        for dim in range(5):
            low, high = 2 * dim, 2 * dim + 1
            # Half-open intervals: touching end points do not overlap.
            if ranges[low] >= self.ranges[high] or \
                    ranges[high] <= self.ranges[low]:
                return False
        return True

    def contains(self, packet):
        """Return True if the 5-field ``packet`` lies inside this node's box.

        The packet is treated as the unit box ``[v, v + 1)`` in each
        dimension and tested for overlap with ``self.ranges``.

        Raises:
            AssertionError: if ``packet`` does not have exactly five fields.
        """
        assert len(packet) == 5, packet
        unit_box = []
        # Field order: src ip, dst ip, src port, dst port, protocol.
        for field in range(5):
            value = packet[field]
            unit_box.append(value)
            unit_box.append(value + 1)
        return self.is_intersect_multi_dimension(unit_box)

    def is_useless(self):
        """Return True if some child holds as many rules as this node.

        A leaf (no children) is never considered useless.
        """
        if not self.children:
            return False
        return max(len(c.rules) for c in self.children) == len(self.rules)

    def pruned_rules(self):
        """Return ``self.rules.prune(self.ranges)``."""
        return self.rules.prune(self.ranges)

    def get_state(self):
        """Encode this node as a flat numeric vector (a NumPy array).

        Layout, in order:

        1. 208 entries from ``to_bits``: for each dimension the low bound
           followed by the inclusive high bound (``high - 1``), using 32
           bits for the two IPs, 16 for the two ports and 8 for the
           protocol.
        2. The partition section.  When ``manual_partition`` is ``None`` it
           is ``onehot_encode`` applied to a per-dimension ``[min, max)``
           pair of levels (presumably 70 entries, matching the other
           branch -- ``onehot_encode`` is defined elsewhere).  Otherwise it
           is a 70-element vector with a single 1 at index
           ``manual_partition``.
        3. ``num_rules``.
        """
        state = []
        state.extend(to_bits(self.ranges[0], 32))
        state.extend(to_bits(self.ranges[1] - 1, 32))
        state.extend(to_bits(self.ranges[2], 32))
        state.extend(to_bits(self.ranges[3] - 1, 32))
        assert len(state) == 128, len(state)
        state.extend(to_bits(self.ranges[4], 16))
        state.extend(to_bits(self.ranges[5] - 1, 16))
        state.extend(to_bits(self.ranges[6], 16))
        state.extend(to_bits(self.ranges[7] - 1, 16))
        assert len(state) == 192, len(state)
        state.extend(to_bits(self.ranges[8], 8))
        state.extend(to_bits(self.ranges[9] - 1, 8))
        assert len(state) == 208, len(state)

        if self.manual_partition is None:
            # One [>=min, <max) pair of levels per dimension.  Levels map to
            # 0%, 2%, 4%, 8%, 16%, 32%, 64%, 100%, e.g.
            #   0, 6 -> 0-64%
            #   6, 7 -> 64-100%
            # Every dimension starts unrestricted at [0, 7).
            partition_state = [0, 7] * 5
            for (smaller, part_dim, part_size) in self.partitions:
                low, high = part_dim * 2, part_dim * 2 + 1
                if smaller:
                    # Node is on the "smaller" side: tighten the upper level.
                    partition_state[high] = min(
                        partition_state[high], part_size + 1)
                else:
                    # Node is on the "larger" side: raise the lower level.
                    partition_state[low] = max(
                        partition_state[low], part_size + 1)
            state.extend(onehot_encode(partition_state, 7))
        else:
            partition_state = [0] * 70
            partition_state[self.manual_partition] = 1
            state.extend(partition_state)
        state.append(self.num_rules)
        return np.array(state)

    def __str__(self):
        """Return a multi-line dump: header, child ids, rules, pushup rules."""
        parts = [
            "ID:%d\tAction:%s\tDepth:%d\tRange:\t%s\nChildren: " % (
                self.id, str(self.action), self.depth, str(self.ranges)),
        ]
        parts.extend(str(child.id) + " " for child in self.children)
        parts.append("\nRules:\n")
        parts.extend(str(rule) + "\n" for rule in self.rules)
        # Identity test: "!= None" would dispatch to the container's own
        # comparison operator, which may be overloaded.
        if self.pushup_rules is not None:
            parts.append("Pushup Rules:\n")
            parts.extend(str(rule) + "\n" for rule in self.pushup_rules)
        return "".join(parts)