Пример #1
0
    def make_simple_model(self) -> Model:
        """Build a minimal model: an Input and a Constant weight feeding one Conv, then an Output.

        Returns:
            Model: a new model whose ``graph`` attribute is the graph built here.
        """
        net = Graph()

        # activation input
        data_in = Input('input', [1, 5, 5, 3], Float32())

        # all-zero convolution kernel
        kernel = Constant('weight', Float32(), np.zeros([1, 2, 2, 3]), dimension_format='NHWC')

        # the convolution consumes both operators above
        conv_inputs = {'X': data_in, 'W': kernel}
        conv_op = Conv('conv', [1, 4, 4, 1], Float32(), conv_inputs, kernel_shape=[2, 2])

        # single graph output
        sink = Output('output', [1, 4, 4, 1], Float32(), {'input': conv_op})

        # add ops to the graph, starting from the output
        net.add_op_and_inputs(sink)

        result = Model()
        result.graph = net
        return result
Пример #2
0
def pass_constant_folding(graph: Graph) -> None:
    """Fold every node whose inputs are all ``Constant`` nodes into a single new Constant.

    Given a node N, if each input of N has ``op_type == 'Constant'`` then N is executed at
    compilation time via ``run_forward()``. A new Constant named ``<N.name>_new`` holding the
    computed output is added to the graph and wired to N's consumers; N and the nodes collected
    by ``get_nodes_in_branch`` are removed afterwards. The pass repeats until an iteration folds
    no further node.

    Args:
        graph (Graph): The input graph. It will be modified in-place.

    """
    # Nodes already folded (and their constant inputs); kept across iterations so that
    # nothing is folded twice and so that termination can be detected.
    processed_nodes = []
    done = False
    while not done:
        exec_list = sort_graph(graph)
        processed_before_precompute = len(processed_nodes)
        to_be_removed = []

        for m in exec_list:
            if m in processed_nodes:
                continue

            # We want operators with inputs
            if not m.input_nodes:
                continue

            # Only nodes fed exclusively by constants can be precomputed
            if not all(input_node.op_type == 'Constant' for input_node in m.input_nodes):
                continue

            processed_nodes += m.input_nodes
            processed_nodes.append(m)

            data = m.run_forward()

            new_constant = Constant(m.name + '_new',
                                    m.dtype,
                                    data,
                                    dimension_format=m.dimension)
            graph.add_op(new_constant)

            # get nodes to be removed after being disconnected
            get_nodes_in_branch(m, None, to_be_removed)

            # NOTE(review): this hands add_outputs a dict view of consumer lists under the
            # single key 'output' -- confirm that this is the shape add_outputs expects.
            new_constant.add_outputs({'output': m.output_ops.values()})

            # Point every consumer of m at the new constant instead (first matching input only)
            for consumer_list in m.output_ops.values():
                for consumer_node in consumer_list:
                    for input_name, input_node in consumer_node.input_ops.items():
                        if input_node == m:
                            consumer_node.add_input(input_name, new_constant)
                            break

        for op in to_be_removed:
            graph.remove_op(op)

        # Stop once a full sweep over the graph folded nothing new
        done = len(processed_nodes) == processed_before_precompute
Пример #3
0
    def create_sample_graph(data1: np.ndarray) -> Graph:
        """Build a graph Input -> Conv (marked quantized) -> SpaceToDepth -> Output.

        Args:
            data1: values stored in the ``weight1`` Constant that feeds the Conv.

        Returns:
            The populated graph.
        """
        net = Graph()

        # activation input
        placeholder = Input('placeholder', [1, 5, 5, 3], Float32())

        # convolution with a QUANTIZED_PACKED output dtype, flagged as quantized
        kernel = Constant('weight1', Float32(), data1)
        conv_inputs = {'X': placeholder, 'W': kernel}
        conv = Conv('conv1', [1, 4, 4, 3], QUANTIZED_PACKED(), conv_inputs, kernel_shape=[2, 2])
        conv.is_quantized = True

        space_to_depth = SpaceToDepth('s2d', [1, 2, 2, 12], Float32(), {'input': conv})

        # single graph output
        sink = Output('output', [1, 2, 2, 12], Float32(), {'input': space_to_depth})

        # add ops to the graph, starting from the output
        net.add_op_and_inputs(sink)
        return net
Пример #4
0
    def create_expected_graph(data: np.ndarray) -> Graph:
        """Build the reference graph: Input and a quantized weight -> Conv -> Reshape -> Output.

        Args:
            data: values stored in the ``weight`` Constant.

        Returns:
            The populated graph.
        """
        net = Graph()

        # activation input
        placeholder = Input('placeholder', [1, 5, 5, 3], Float32())

        # weight constant followed by its quantizer
        kernel = Constant('weight', Float32(), data)
        kernel_quantizer = QTZ_binary_mean_scaling('qtz1', [1, 2, 2, 3], Float32(), {'input': kernel})

        # convolution over the input and the quantized weight
        conv_inputs = {'X': placeholder, 'W': kernel_quantizer}
        conv_op = Conv('conv', [1, 4, 4, 3], Float32(), conv_inputs, kernel_shape=[2, 2])

        # flatten to [1, 48] and expose as the single output
        flattened = Reshape('reshape', [1, 48], Float32(), {'data': conv_op})
        sink = Output('output', [1, 48], Float32(), {'input': flattened})

        # add ops to the graph, starting from the output
        net.add_op_and_inputs(sink)
        return net
Пример #5
0
 def add_all_nodes(self, graph: Graph) -> None:
     """Populate ``graph`` by walking from the first entry of ``self.out_lst``.

     The walk is started with the data format 'NHWC'. Any nodes the walk records
     for removal are removed from ``graph`` once it has finished.
     """
     seen: Set[Any] = set()
     created: Dict[str, Operator] = {}
     stale_ops = []
     root = self.out_lst[0]
     self.add_node_to_graph_recursive(root, graph, seen, created, 'NHWC', stale_ops)
     for stale in stale_ops:
         graph.remove_op(stale)
Пример #6
0
    def create_sample_graph(data1: np.ndarray, data2: np.ndarray) -> Graph:
        """Build a two-convolution graph with an activation quantizer and a kernel quantizer.

        Args:
            data1: values of the ``weight1`` Constant (first Conv).
            data2: values of the ``weight2`` Constant (second Conv, via the kernel quantizer).

        Returns:
            The populated graph.
        """
        net = Graph()

        # activation input
        placeholder = Input('placeholder', [1, 5, 5, 3], Float32())

        # first convolution
        kernel_a = Constant('weight1', Float32(), data1)
        first_conv = Conv(
            'conv1', [1, 4, 4, 3], Float32(),
            {'X': placeholder, 'W': kernel_a},
            kernel_shape=[2, 2],
        )

        # activation quantizer with its two scalar constants
        aq_arg_y = Constant('aq_const1', Float32(), np.array(1))
        aq_arg_z = Constant('aq_const2', Float32(), np.array(2))
        act_quantizer = QTZ_linear_mid_tread_half(
            'aqtz1', [1, 4, 4, 3], Float32(),
            {'X': first_conv, 'Y': aq_arg_y, 'Z': aq_arg_z},
        )

        # second convolution, fed by the quantized activation and the quantized kernel
        kernel_b = Constant('weight2', Float32(), data2)
        kernel_quantizer = QTZ_binary_mean_scaling('kqtz1', [1, 2, 2, 3], Float32(), {'input': kernel_b})
        second_conv = Conv(
            'conv2', [1, 3, 3, 3], Float32(),
            {'X': act_quantizer, 'W': kernel_quantizer},
            kernel_shape=[2, 2],
        )
        second_conv.a_quantizer = [act_quantizer]
        second_conv.quantizer = kernel_quantizer

        # single graph output
        sink = Output('output', [1, 3, 3, 3], Float32(), {'input': second_conv})

        # add ops to the graph, starting from the output
        net.add_op_and_inputs(sink)
        return net
Пример #7
0
    def add_node_to_graph_recursive(self, current: Any, graph: Graph, visited: Set[Any], added: Dict[str, Operator],
                                    data_format: str, nodes_to_remove) \
            -> Operator:
        """Convert ``current`` and, depth-first, all of its inputs into operators added to ``graph``.

        Args:
            current: the source node to convert.
            graph: graph that receives the newly created operators.
            visited: source nodes already converted; updated in place.
            added: operators created so far, keyed by operator name; updated in place.
            data_format: format passed to ``_get_format`` for ``current``.
            nodes_to_remove: list forwarded to ``create_new_op`` and to the recursive calls.

        Returns:
            The operator for ``current``: the cached one if it was visited before,
            otherwise the newly created one.
        """
        # Already converted: reuse the operator registered under this node's name.
        if current in visited:
            return added[current.name]

        current_format, input_formats = self._get_format(current, data_format)

        # Convert each input first, pairing it with the format chosen for it.
        input_ops: Dict[str, Operator] = {}
        for source, source_format in zip(self.find_inputs(current), input_formats):
            converted = self.add_node_to_graph_recursive(
                source, graph, visited, added, source_format, nodes_to_remove)
            input_ops[converted.name] = converted

        new_op = self.create_new_op(current, input_ops, current_format, input_formats, nodes_to_remove)
        graph.add_op(new_op)

        visited.add(current)
        added[new_op.name] = new_op
        return new_op
Пример #8
0
def pass_remove_identities(graph: Graph) -> None:
    """Remove every node of the graph whose ``op_type`` is ``'Identity'``.

    Each consumer of an Identity is rewired to read directly from the Identity's input,
    and the Identity's slot in its producer's output list is replaced by those consumers.
    All consumers are registered on the producer, not only the first one.

    Parameters
    ----------
    graph : Graph
        The input graph. It will be modified in-place.

    """
    identities = [n for n in sort_graph(graph) if n.op_type == 'Identity']
    to_be_removed = []
    for identity in identities:
        producer = identity.input_ops['input']
        consumers = identity.output_ops['output']

        # Change each consumer's input from this identity to the identity's input.
        rewired = []
        for consumer in consumers:
            for input_name, input_op in consumer.input_ops.items():
                if input_op == identity:
                    consumer.add_input(input_name, producer)
                    rewired.append(consumer)
                    break

        # Change the producer's output from this identity to the rewired consumers.
        # Done once per identity so that every consumer is recorded; removing the
        # identity inside the per-consumer loop would drop all consumers but the first.
        if rewired:
            for out_list in producer.output_ops.values():
                if identity in out_list:
                    out_list.remove(identity)
                    out_list.extend(rewired)
                    break

        to_be_removed.append(identity)

    for op in to_be_removed:
        graph.remove_op(op)
Пример #9
0
    def add_all_nodes(self, graph: Graph) -> None:
        """Populate ``graph`` by walking from the first entry of ``self.out_lst``.

        The starting data format is chosen from the rank of that entry's shape
        (2 -> 'NC', 4 -> 'NHWC'); any other rank raises ``KeyError``. Nodes the walk
        records for removal are removed from ``graph`` once it has finished.
        """
        seen: Set[Any] = set()
        created: Dict[str, Operator] = {}
        stale_ops = []

        # FIXME: default output format is NC/NHWC (ad-hoc workaround)
        rank_to_format = {2: 'NC', 4: 'NHWC'}
        root = self.out_lst[0]
        root_format = rank_to_format[len(root.get_shape())]

        self.add_node_to_graph_recursive(root, graph, seen, created, root_format, stale_ops)
        for stale in stale_ops:
            graph.remove_op(stale)
Пример #10
0
def test_can_find_edge():
    """An edge added between two existing nodes is reported by ``findEdge``."""
    source, target = "start", "end"
    g = Graph()
    for label in (source, target):
        g.addNode(label)
    g.addEdge(source, target)
    # Compared with == True on purpose: the result must equal True, not merely be truthy.
    assert g.findEdge(source, target) == True
Пример #11
0
def test_can_remove_edge():
    """Removing an edge that was added between two existing nodes reports success."""
    source, target = "start", "end"
    g = Graph()
    for label in (source, target):
        g.addNode(label)
    g.addEdge(source, target)
    # Compared with == True on purpose: the result must equal True, not merely be truthy.
    assert g.removeEdge(source, target) == True
Пример #12
0
 def __init__(self):
     """
     Initialise an empty driver.

     switches: a dictionary that maps switch names to components
     hosts: a dictionary that maps host names to components

     name, home and handle start as None, links starts as an empty dictionary,
     and graph is a new Graph created after the base-class initialiser has run.
     """
     self.name = self.home = self.handle = None
     self.switches, self.hosts, self.links = {}, {}, {}
     super(NetworkDriver, self).__init__()
     self.graph = Graph()
Пример #13
0
def flow_network(nodes, edges):
    """Run Edmonds-Karp maximum flow from the first node to the node at index ``len(nodes) - 2``.

    The trailing ``True`` is passed straight to ``Graph``.
    NOTE(review): presumably it marks the graph as directed -- confirm against Graph.
    """
    source = nodes[0]
    sink = nodes[len(nodes) - 2]
    network = Graph(source, sink, nodes, edges, True)
    Edmonds_Karp(network).find_maximum_flow()
Пример #14
0
def breadth_first_search(nodes, edges):
    """Run a breadth-first search rooted at every node in turn.

    ``reset(nodes, edges)`` is called after each search, before the next root is tried.
    """
    for start in nodes:
        bfs(Graph(start, nodes, edges, False)).search()
        reset(nodes, edges)
Пример #15
0
    def test_graphrunner_breadth_first(self) -> None:
        """Check GraphRunner's traversal order and messages in breadth-first, non-lazy mode."""
        graph = Graph()
        self.create_graph(graph)

        # The runner appends visited node names to these two lists.
        kwargs: Dict[str, List[str]] = {'backward': [], 'forward': []}
        TestRunner_instance = TestRunner(graph, depth_first=False, lazy=False)
        TestRunner_instance.run(**kwargs)

        expected_backward = [
            'output', 'conv4', 'input3', 'conv3', 'input2',
            'conv2', 'conv1', 'weight2', 'input1', 'weight1',
        ]
        self.assertEqual(kwargs['backward'], expected_backward,
                         'backward traversal failed in breadth-first mode.')

        expected_forward = [
            'input3', 'input2', 'input1', 'weight1', 'weight2',
            'conv4', 'conv3', 'conv1', 'conv2', 'output',
        ]
        self.assertEqual(kwargs['forward'], expected_forward,
                         'forward traversal failed in breadth-first mode.')

        expected_messages = ['start running.']
        expected_messages += ['conv4: backward process', 'conv3: backward process',
                              'conv2: backward process', 'conv1: backward process']
        expected_messages += ['conv4: forward process', 'conv3: forward process',
                              'conv1: forward process', 'conv2: forward process']
        expected_messages.append('finished running.')
        self.assertEqual(TestRunner_instance.message, expected_messages)

        print("GraphRunner bradth-first mode test passed!")
Пример #16
0
def floyd_warshall(nodes, edges):
    """Run the FloydWarshall search once per node, using that node as the root.

    The root's sequence and label are printed before each run, and
    ``reset(nodes, edges)`` is called after it.
    """
    for start in nodes:
        print("root-> {} {}".format(start.get_sequence(), start.get_rotulo()))
        solver = FloydWarshall(Graph(start, nodes, edges, True))
        solver.search()
        reset(nodes, edges)
Пример #17
0
def djikstra(nodes, edges):
    """Run the SPF shortest-path search once per node, using that node as the root.

    The root's sequence and label are printed before each run, and
    ``reset(nodes, edges)`` is called after it.
    """
    for start in nodes:
        print("root-> {} {}".format(start.get_sequence(), start.get_rotulo()))
        solver = SPF(Graph(start, nodes, edges, True))
        solver.search_shortest_path()
        reset(nodes, edges)
Пример #18
0
    def create_sample_graph() -> Graph:
        """Build a graph in which one Add has only Constant inputs and feeds a second Add.

        Returns:
            The populated graph.
        """
        net = Graph()

        placeholder = Input('placeholder', [2], Float32())

        # Add whose two operands are both constants
        left = Constant('potato_1', Float32(), np.array([1, 2]))
        right = Constant('potato_2', Float32(), np.array([1, 3]))
        constant_sum = Add('potatoes', [2], Float32(), {'A': left, 'B': right})

        # Add that mixes the runtime input with the constant sum
        mixed_sum = Add('more_potatoes', [2], Float32(), {'A': placeholder, 'B': constant_sum})

        # single graph output
        sink = Output('output', [2], Float32(), {'input': mixed_sum})

        # add ops to the graph, starting from the output
        net.add_op_and_inputs(sink)
        return net
Пример #19
0
    def create_sample_graph(data1: np.ndarray, data2: np.ndarray) -> Graph:
        """Build Conv -> activation quantizer -> quantized Conv -> BatchNormalization -> activation quantizer.

        Args:
            data1: values of the ``weight1`` Constant (first Conv).
            data2: values of the ``weight2`` Constant (second Conv, via the kernel quantizer).

        Returns:
            The populated graph. The batch-normalization constants are drawn from ``np.random.rand``.
        """
        net = Graph()

        # activation input
        placeholder = Input('placeholder', [1, 5, 5, 3], Float32())

        # first convolution
        kernel_a = Constant('weight1', Float32(), data1)
        first_conv = Conv(
            'conv1', [1, 4, 4, 3], Float32(),
            {'X': placeholder, 'W': kernel_a},
            kernel_shape=[2, 2],
        )

        # first activation quantizer
        aq1_y = Constant('aq_const1', Int32(), np.array([2], dtype=np.int32))
        aq1_z = Constant('aq_const2', Float32(), np.array([2.0], dtype=np.float32))
        first_act_quantizer = QTZ_linear_mid_tread_half(
            'aqtz1', [1, 4, 4, 3], Float32(),
            {'X': first_conv, 'Y': aq1_y, 'Z': aq1_z},
        )

        # second convolution, flagged as quantized
        kernel_b = Constant('weight2', Float32(), data2)
        kernel_quantizer = QTZ_binary_mean_scaling('kqtz1', [1, 2, 2, 3], Float32(), {'input': kernel_b})
        second_conv = Conv(
            'conv2', [1, 3, 3, 3], Float32(),
            {'X': first_act_quantizer, 'W': kernel_quantizer},
            kernel_shape=[2, 2],
        )
        second_conv.a_quantizer = [first_act_quantizer]
        second_conv.quantizer = kernel_quantizer
        second_conv.is_quantized = True

        # batch normalization with random parameters (created in scale, B, mean, var order)
        bn_scale = Constant('bn_scale', Float32(), np.random.rand(3))
        bn_bias = Constant('bn_b', Float32(), np.random.rand(3))
        bn_mean = Constant('bn_mu', Float32(), np.random.rand(3))
        bn_variance = Constant('bn_var', Float32(), np.random.rand(3))
        bn_inputs = {
            'X': second_conv,
            'scale': bn_scale,
            'B': bn_bias,
            'mean': bn_mean,
            'var': bn_variance,
        }
        batch_norm = BatchNormalization('bn', [1, 3, 3, 3], Float32(), bn_inputs)

        # second activation quantizer
        aq2_y = Constant('aq_const3', Int32(), np.array([2], dtype=np.int32))
        aq2_z = Constant('aq_const4', Float32(), np.array([2.0], dtype=np.float32))
        second_act_quantizer = QTZ_linear_mid_tread_half(
            'aqtz2', [1, 3, 3, 3], Float32(),
            {'X': batch_norm, 'Y': aq2_y, 'Z': aq2_z},
        )

        # single graph output
        sink = Output('output', [1, 3, 3, 3], Float32(), {'input': second_act_quantizer})

        # add ops to the graph, starting from the output
        net.add_op_and_inputs(sink)
        return net
Пример #20
0
def eulerian_path_finder(nodes, edges):
    """Run the EulerianPath search once per node, using that node as the root.

    After each run ``reset(nodes, edges)`` is called and a ``#####`` separator is printed.
    """
    for start in nodes:
        finder = EulerianPath(Graph(start, nodes, edges, True))
        finder.search()
        reset(nodes, edges)

        print("#####")
Пример #21
0
    def create_transposed_graph(self, data: np.ndarray) -> Graph:
        """Build the NHWC graph: weight -> Identity -> quantizer -> Conv -> Reshape -> Output.

        Args:
            data: weight values; its axes are reversed with ``transpose([3, 2, 1, 0])``
                before being stored in the ``weight`` Constant.

        Returns:
            The populated graph.
        """
        net = Graph()
        transposed = data.transpose([3, 2, 1, 0])
        layout = 'NHWC'

        # activation input
        placeholder = Input('placeholder', [1, 5, 5, 3], Float32(), dimension_format=layout)

        # weight constant, passed through an identity and then a quantizer
        kernel = Constant('weight', Float32(), transposed, dimension_format=layout)
        passthrough = Identity('identity1', [1, 2, 2, 3], Float32(), {'input': kernel},
                               dimension_format=layout)
        kernel_quantizer = QTZ_binary_mean_scaling('qtz1', [1, 2, 2, 3], Float32(), {'input': passthrough},
                                                   dimension_format=layout)

        # convolution over the input and the quantized weight
        conv_inputs = {'X': placeholder, 'W': kernel_quantizer}
        conv_op = Conv('conv', [1, 4, 4, 3], Float32(), conv_inputs,
                       kernel_shape=[2, 2], dimension_format=layout)

        flattened = Reshape('reshape', [1, 48], Float32(), {'data': conv_op})

        # single graph output
        sink = Output('output', [1, 48], Float32(), {'input': flattened})

        # add ops to the graph, starting from the output
        net.add_op_and_inputs(sink)
        return net
Пример #22
0
    def create_sample_graph_2(data1: np.ndarray) -> Graph:
        """Build a graph Input -> Conv -> Add (with an all-zero constant) -> Output.

        Args:
            data1: values stored in the ``weight1`` Constant that feeds the Conv.

        Returns:
            The populated graph.
        """
        net = Graph()

        # activation input
        placeholder = Input('placeholder', [1, 5, 5, 3], Float32())

        # convolution
        kernel = Constant('weight1', Float32(), data1)
        conv = Conv(
            'conv1', [1, 4, 4, 3], Float32(),
            {'X': placeholder, 'W': kernel},
            kernel_shape=[2, 2],
        )

        # add an all-zero constant of the convolution's output shape
        zeros = Constant('const1', Float32(), np.zeros([1, 4, 4, 3]))
        summed = Add('add', [1, 4, 4, 3], Float32(), {'A': conv, 'B': zeros})

        sink = Output('output', [1, 4, 4, 3], Float32(), {'input': summed})

        # add ops to the graph, starting from the output
        net.add_op_and_inputs(sink)
        return net
Пример #23
0
    def colour(self):
        """Compute and print the minimum number of colours via dynamic programming over the power set.

        For every subset key ``s`` of the power set (key 0 is skipped), ``X[s]`` is set to the
        smallest ``X[i] + 1`` over the maximal independent sets of the subgraph induced by the
        subset, where ``i`` is the index returned by ``find_index`` for that independent set.
        The last entry of ``X`` is printed as the result.
        """
        power_set = self.graph.power_set()
        X = self.init_colour_array(power_set)

        for s in power_set:
            if s == 0:
                continue

            # Best (smallest) candidate found so far for this subset.
            best = INFINITY

            nodes = self.get_nodes_in_set(power_set[s])
            edges = self.get_nodes_edges(nodes)
            subgraph = Graph(nodes[0], nodes[-1], nodes, edges, False)

            for independent_set in subgraph.maximal_independent_sets():
                i = self.find_index(power_set, nodes, independent_set)
                candidate = X[i] + 1
                if candidate < best:
                    # Remember the new minimum; without updating `best`, a later and
                    # larger candidate would overwrite X[s] with a non-minimal value.
                    best = candidate
                    X[s] = best

        print("Número minimo de colorações: {}".format(X[len(X) - 1]))
Пример #24
0
    def create_sample_graph(data: np.ndarray) -> Graph:
        """Build the CWHN graph: weight -> Identity -> quantizer -> Conv -> Reshape -> Output.

        Args:
            data: values stored in the ``weight`` Constant.

        Returns:
            The populated graph.
        """
        net = Graph()
        layout = 'CWHN'

        # activation input
        placeholder = Input('placeholder', [3, 5, 5, 1], Float32(), dimension_format=layout)

        # weight constant, passed through an identity and then a quantizer
        kernel = Constant('weight', Float32(), data, dimension_format=layout)
        passthrough = Identity('identity1', [3, 2, 2, 1], Float32(), {'input': kernel},
                               dimension_format=layout)
        kernel_quantizer = QTZ_binary_mean_scaling('qtz1', [3, 2, 2, 1], Float32(), {'input': passthrough},
                                                   dimension_format=layout)

        # convolution over the input and the quantized weight
        conv_inputs = {'X': placeholder, 'W': kernel_quantizer}
        conv_op = Conv('conv', [3, 4, 4, 1], Float32(), conv_inputs,
                       kernel_shape=[2, 2], dimension_format=layout)

        # flatten to [1, 48] and expose as the single output
        flattened = Reshape('reshape', [1, 48], Float32(), {'data': conv_op})
        sink = Output('output', [1, 48], Float32(), {'input': flattened})

        # add ops to the graph, starting from the output
        net.add_op_and_inputs(sink)
        return net
 def __init__( self ):
     """
     Initialise an empty driver.

     switches: a dictionary that maps switch names to components
     hosts: a dictionary that maps host names to components

     name, home and handle start as None, links starts as an empty dictionary,
     and graph is a new Graph created after the base-class initialiser has run.
     """
     self.name = self.home = self.handle = None
     self.switches, self.hosts, self.links = {}, {}, {}
     super( NetworkDriver, self ).__init__()
     self.graph = Graph()
Пример #26
0
    def test_graph_conv(self) -> None:
        """Build a simple Conv graph one operator at a time and validate it with ``check_nodes``."""
        graph = Graph()

        # activation input and all-zero kernel
        data_in = Input('input', [1, 5, 5, 3], Float32())
        kernel = Constant('weight', Float32(), np.zeros([1, 2, 2, 3]))

        # the input keys are listed in 'Conv.input_names'
        conv_inputs = {'X': data_in, 'W': kernel}
        conv_op = Conv('conv', [1, 4, 4, 3], Float32(), conv_inputs, kernel_shape=[2, 2])

        # the input key is listed in 'Output.input_names'
        sink = Output('output', [1, 4, 4, 3], Float32(), {'input': conv_op})

        # add ops to the graph explicitly, producers before consumers
        for op in (data_in, kernel, conv_op, sink):
            graph.add_op(op)

        self.assertTrue(graph.check_nodes(),
                        "All inputs of operators must match their outputs.")
        print("Graph test passed!")
Пример #27
0
def generate_sv_single_video(context):
    """Run single-view tracking for one detection file and save the result as .pb and .json.

    Args:
        context: a ``(detection_file, args)`` pair. ``args`` supplies ``output_path``,
            ``config_path``, ``start_frame``, ``num_frames_process`` and ``do_augmentation``.

    Any exception is logged through ``mlog.info`` and the process then exits.
    NOTE(review): ``sys.exit()`` without an argument exits with status 0 even though
    an error occurred -- confirm this is intended.
    """
    try:
        detection_file, args = context
        file_name = os.path.basename(detection_file)

        # Output names are derived from the part of the file name before the first dot.
        stem = file_name.split('.')[0]
        output_pb_file = os.path.join(args.output_path, stem + '.mp4.cut.pb')
        output_json_file = os.path.join(args.output_path, stem + '.mp4.cut.mp4.final.reduced.json')

        cname = "single_view"
        moving_cameras = ["MOT17-05", "MOT17-06", "MOT17-07", "MOT17-10", "MOT17-11", "MOT17-12", "MOT17-13", "MOT17-14"]

        ############################################################
        # load and set configs
        # The video name is the first two dash-separated fields of the file name.
        name_fields = file_name.split('-')
        video_name = name_fields[0] + '-' + name_fields[1]
        if video_name in moving_cameras:
            config_name = "single_view_online_moving.json"
        else:
            config_name = "single_view_online_static.json"
        config_file = os.path.join(args.config_path, config_name)
        with open(config_file, 'r') as f:
            configs = json.load(f)

        ############################################################
        # load data
        last_frame = args.start_frame + args.num_frames_process - 1
        nodes = load_detections_to_nodes(detection_file,
                                         cname,
                                         configs["detection_confidence"],
                                         args.start_frame,
                                         last_frame,
                                         args.do_augmentation)

        opts = {}
        affinities = create_affinities(configs["affinity"], opts)
        graph = Graph(nodes, affinities, cname)

        engine = create_algorithm(configs["algorithm"])

        output = engine(graph, cname)
        save_nodes_online_pbs(output, cname, output_pb_file)
        save_nodes_to_json(output, cname, output_json_file)
    except Exception as e:
        mlog.info(e)
        sys.exit()
Пример #28
0
def pass_lookup(graph: Graph) -> None:
    """Replace a Gather-based quantizer branch with a single Lookup node.

    A quantizer Q is rewritten when its first-input chain is
    Reshape <- Reshape <- Gather <- Gather and the first input of the innermost Gather
    has rank 2 with 256 rows. The quantizer is run on that table, each quantized row is
    packed, and the first two packed words of every row become the ``lsb`` / ``msb``
    constants of a new Lookup node fed by the graph's first Input node. The branch between
    the Input and Q is collected for removal and Q's consumers are rewired to the Lookup node.

    Parameters
    ----------
    graph : Graph
        The input graph. It will be modified in-place.
    """
    quantization_types = [
        'QTZ_binary_mean_scaling', 'QTZ_linear_mid_tread_half',
        'QTZ_binary_channel_wise_mean_scaling'
    ]
    table_size = 256
    word_size = 32

    to_be_removed = []
    sorted_nodes = sort_graph(graph)
    exec_list = [n for n in sorted_nodes if n.op_type in quantization_types]
    # Exact match: `op_type in 'Input'` would be a substring test and would also
    # accept op types such as '' or 'In'.
    placeholder = [n for n in sorted_nodes if n.op_type == 'Input']

    for quantizer in exec_list:
        # Walk back along the first input of each node and check the expected pattern.
        p1 = quantizer.input_nodes[0]
        if p1.op_type != 'Reshape':
            continue
        p2 = p1.input_nodes[0]
        if p2.op_type != 'Reshape':
            continue
        p3 = p2.input_nodes[0]
        if p3.op_type != 'Gather':
            continue
        p4 = p3.input_nodes[0]
        if p4.op_type != 'Gather':
            continue
        gather_params = p4.input_nodes[0]
        if gather_params.rank != 2 or gather_params.shape[0] != table_size:
            continue

        # Quantize the whole table once at compile time.
        qtz_data = quantizer.run(data=gather_params.data)['data']

        lu_bitwidth = quantizer.nbit
        packer = Packer(lu_bitwidth, word_size)

        lsb = np.zeros((table_size, ), np.uint32)
        msb = np.zeros((table_size, ), np.uint32)

        # Pack every quantized row; its first two packed words are stored as lsb / msb.
        for idx, row in enumerate(qtz_data):
            packed = packer.run(row.astype(np.float32), row.shape).flatten()
            lsb[idx] = packed[0]
            msb[idx] = packed[1]

        pe_lsb = Constant('pe_lsb_new',
                          QUANTIZED_PACKED_KERNEL(),
                          lsb,
                          dimension_format='TC',
                          packed=True,
                          actual_shape=[table_size, word_size])
        pe_msb = Constant('pe_msb_new',
                          QUANTIZED_PACKED_KERNEL(),
                          msb,
                          dimension_format='TC',
                          packed=True,
                          actual_shape=[table_size, word_size])

        # The quantizer output must be 4-D; only its two middle dimensions are reused.
        _, h, w, _ = quantizer.shape
        shape = [1, h, w, 2, word_size]
        pe = Lookup('Lookup',
                    shape,
                    QUANTIZED_PACKED(), {
                        'input': placeholder[0],
                        'lsb': pe_lsb,
                        'msb': pe_msb
                    },
                    dimension_format='ChHWBCl')

        # Collect the branch between the input and the quantizer, then splice in the Lookup.
        get_nodes_in_branch(quantizer, placeholder[0], to_be_removed)
        placeholder[0].remove_output('output')
        placeholder[0].add_output('output', pe)
        pe.add_outputs(quantizer.output_ops)

        output_op = quantizer.output_op_list[0]

        # Find which input of the first consumer the quantizer fed; default to 'X'.
        target_input_name = 'X'
        for input_name in output_op._input_names:
            if quantizer.equals(output_op._input_ops[input_name]):
                target_input_name = input_name
                break

        output_op.add_input(target_input_name, pe)

        graph.add_op(pe_lsb)
        graph.add_op(pe_msb)
        graph.add_op(pe)

    for op in to_be_removed:
        graph.remove_op(op)
0
def pass_pack_weights(graph: Graph) -> None:
    """Given a Quantized convolution node C, it will pack the weights of C into 32 bit words.
       If the node Q that apply quantization to the weights of C quantizes, for example, into 1 bit values
       then one 32 bit word will contain 32 weights.

       For each such convolution the quantized weights are zero-padded along axes 0 and 3 to a
       multiple of 32, and three packed variants are produced: the padded data in its original
       order, a blocked re-ordering (``transposed_data``) and a re-ordering described by
       ``kn2row_shape`` (``kn2row_data``). A new packed Constant named ``<Q.name>_new`` replaces Q
       for Q's consumers, and the nodes collected by ``get_nodes_in_branch`` are removed.

    Parameters
    ----------
    graph : Graph
        The input graph. It will be modified in-place.
    """
    exec_list = [n for n in sort_graph(graph) if n.op_type == 'Conv']
    quantization_types = [
        'QTZ_binary_mean_scaling', 'QTZ_linear_mid_tread_half',
        'QTZ_binary_channel_wise_mean_scaling'
    ]

    word_size = 32
    weight_bitwidth = 1
    packer = Packer(weight_bitwidth, word_size)
    to_be_removed = []
    # Padding granularity and block size used by the blocked re-ordering below
    b = 32

    for m in exec_list:
        conv_node = m

        # check if this is a quantized convolution
        if not conv_node.quantizer or not conv_node.a_quantizer:
            continue

        # Check if we support this kind of quantizer
        weight_quantizer = conv_node.quantizer
        if weight_quantizer.op_type not in quantization_types:
            continue

        # Quantize the weights
        weight_quantizer.run_forward()

        # Returns a zero tensor shaped like `tensor`, except that `axis` has the length
        # needed to bring `tensor` up to the next multiple of `b`; returns None when
        # `tensor` is already a multiple of `b` along `axis`.
        def pad_to_multiple_of_b(tensor, axis, b):
            shape = list(tensor.shape)
            pad = (((shape[axis] + b - 1) // b) * b) - shape[axis]
            shape[axis] = pad
            return np.zeros(shape) if pad else None

        padded_data = np.copy(weight_quantizer.data)

        # Zero-pad the first and the last axis to a multiple of b
        for axis in [0, 3]:
            pad_tensor = pad_to_multiple_of_b(padded_data, axis, b)
            if pad_tensor is not None:
                padded_data = np.append(padded_data, pad_tensor, axis=axis)

        tca_output = np.copy(padded_data)
        oc, kh, kw, kd = padded_data.shape[:]
        padded_data = padded_data.flatten()
        tca_output = tca_output.flatten()

        # Blocked re-ordering: output is written sequentially in (g, p, h, w, o, d) loop
        # order, i.e. the layout described by tca_shape below, while `idx` addresses the
        # same element in the flat (oc, kh, kw, kd) source layout.
        out_index = 0
        for g in range(oc // b):
            for p in range(kd // b):
                for h in range(kh):
                    for w in range(kw):
                        for o in range(b):
                            for d in range(b):
                                idx = g * (kw * kh * kd * b) + p * b + h * (
                                    kw * kd) + w * kd + o * (kw * kh * kd) + d
                                tca_output[out_index] = padded_data[idx]
                                out_index += 1

        # Second re-ordering: output is written in (h, w, o, i) loop order, i.e. the layout
        # described by kn2row_shape below, read from the flat (oc, kh, kw, kd) source layout.
        kn2row_output = np.zeros(oc * kh * kw * kd)
        out_index = 0
        for h in range(kh):
            for w in range(kw):
                for o in range(oc):
                    for i in range(kd):
                        idx = o * kh * kw * kd + h * kw * kd + w * kd + i
                        kn2row_output[out_index] = padded_data[idx]
                        out_index += 1

        # Binarize and pack each of the three layouts
        op_data = weight_quantizer.binarizer(padded_data)
        data = packer.run(op_data.astype(np.float32),
                          weight_quantizer.dimension)

        tca_binarized_data = weight_quantizer.binarizer(tca_output)
        tca_packed_data = packer.run(tca_binarized_data.astype(np.float32),
                                     weight_quantizer.dimension)

        kn2row_binarized_data = weight_quantizer.binarizer(kn2row_output)
        kn2row_data = packer.run(kn2row_binarized_data.astype(np.float32),
                                 weight_quantizer.dimension)

        shape = [oc, kh, kw, kd]
        tca_shape = [oc // b, kd // b, kh, kw, b, b]
        kn2row_shape = [kh, kw, oc, kd]

        # Create the new constant with the quantized weights.
        # `data` and `transposed_data` are stored bitwise-inverted and masked to 32 bits;
        # `kn2row_data` is stored as packed, without inversion.
        quantized_constant = Constant(
            weight_quantizer.name + '_new',
            PackedUint32(),
            data=np.vectorize(lambda k: (~k) & ((0x1 << 32) - 1))(data),
            dimension_format="NHWC",
            transposed_dimension_format="OhIhHWOlIl",
            packed=True,
            actual_shape=shape,
            transposed_shape=tca_shape,
            transposed_data=[(~k) & ((0x1 << 32) - 1)
                             for k in tca_packed_data.flatten()],
            kn2row_data=[k for k in kn2row_data.flatten()],
            kn2row_shape=kn2row_shape,
            kn2row_dimension_format="HWNC")

        # get nodes to be removed after being disconnected
        get_nodes_in_branch(weight_quantizer, None, to_be_removed)

        # Add the constant to the graph and connect the new constant
        graph.add_op(quantized_constant)
        quantized_constant.add_outputs(weight_quantizer.output_ops)
        # Re-point each consumer's first input that referenced the quantizer
        for output_name, consumer_list in weight_quantizer.output_ops.items():
            for consumer_node in consumer_list:
                for input_name, input_node in consumer_node.input_ops.items():
                    if input_node == weight_quantizer:
                        consumer_node.add_input(input_name, quantized_constant)
                        break

    for op in to_be_removed:
        graph.remove_op(op)
Пример #30
0
def pass_compute_thresholds(graph: Graph) -> None:
    """Fold each monotonic ``Conv -> ... -> Quantizer`` path into thresholds.

    Given a Quantizer node Q:
      - if there is a backward path between Q and a convolution node and,
      - every node N of that path satisfies the condition N.is_monotonic and,
      - the convolution node C (the end of this path) is a quantized
        convolution
    then this pass constructs a LUT per channel which maps a possible output
    value of the quantized convolution node C to the corresponding output of
    the quantization node Q. This effectively compresses the path
    C -> ... -> Q into a list of LUTs that can be used during inference.

    The flattened table is stored in ``conv_node.thresholds``; the nodes
    between C and Q are then removed from the graph and the consumers of Q
    are rewired to read from C.

    Parameters
    ----------
    graph : Graph
        The input graph. It will be modified in-place.

    Raises
    ------
    ValueError
        If the activation quantizers of a convolution disagree on the bit
        width or on the maximum value.
    """
    exec_list = [
        n for n in sort_graph(graph)
        if n.op_type == 'QTZ_linear_mid_tread_half'
    ]
    to_be_removed = []
    for m in exec_list:
        # find a backward path between the quantizer and the convolution,
        # i.e. a path represented by a list [Q, ..., C]
        p = [m]
        while p[-1].op_type != 'Conv':
            non_variable_input = [
                inode for inode in p[-1].input_nodes
                if (not cast(Operator, inode).is_variable
                    and inode.is_monotonic) or inode.op_type == 'Conv'
            ]
            # the path must be unambiguous: exactly one candidate predecessor
            if len(non_variable_input) != 1:
                break
            p.append(non_variable_input[-1])

        if p[-1].op_type != 'Conv':
            continue
        activation_quantizer_node = p[0]
        conv_node = p[-1]

        # check if this is a quantized convolution
        if not conv_node.quantizer or not conv_node.a_quantizer:
            continue

        quantizer_conv_weights = conv_node.quantizer
        quantizer_conv_weights.run_forward_no_scaling_factor()
        scaling_factor = quantizer_conv_weights.scaling_factor

        # Getting the bit and max value
        nbits = [aqtz.nbit for aqtz in conv_node.a_quantizer]
        max_vs = [aqtz.max_v for aqtz in conv_node.a_quantizer]
        # Only the first entry of each list is used below, so a mismatch in
        # EITHER list must be rejected (the former `and` let a single
        # mismatch slip through).
        if len(set(nbits)) != 1 or len(set(max_vs)) != 1:
            raise ValueError(
                f'bits {nbits} or max values {max_vs} are not consistent')
        nbit = nbits[0]
        max_v = max_vs[0]

        n = 2**nbit - 1
        ch = conv_node.channel
        # assume that the threshold values will be a 13-bit signed integer
        max_th_value = 2**12 - 1

        # The threshold_table is a numpy array that holds the threshold values
        # for all channels: n thresholds per channel plus one trailing flag
        # column that is filled in after the thresholds.
        threshold_table = np.empty([ch, n + 1], dtype=np.int32)

        # Compute threshold (t0, t1, t2)
        th_val = [0.5 + i for i in range(n)]
        for th_id, th_v in enumerate(th_val):
            init_threshold = np.full(ch, th_v, dtype=np.float64)

            # run calculation in reverse order, for example, q -> bn -> scaling
            bn_nega_idx = []
            trans_th = {'data': init_threshold}
            for op in p[:-1]:
                trans_th = op.de_run(**trans_th)
                if op.op_type == 'BatchNormalization':
                    # remember the channels whose batch-norm scale is negative
                    bn_scale = op.input_ops['scale'].data
                    bn_nega_idx = [
                        v for v in range(len(bn_scale)) if bn_scale[v] < 0
                    ]
            threshold = (trans_th['data'] *
                         np.float64(n)) / (np.float64(max_v) * scaling_factor)

            # take care of threshold values that are larger than 13-bit signed integer
            threshold[threshold > max_th_value] = max_th_value
            threshold[threshold < -max_th_value] = -max_th_value

            # Round down when exactly one of (weight scaling factor, batch-norm
            # scale) is negative for the channel, otherwise round up.
            for ch_id, th_per_ch in enumerate(threshold):
                if quantizer_conv_weights.op_type == 'QTZ_binary_channel_wise_mean_scaling':
                    # channel-wise quantizer: one scaling factor per channel
                    threshold_table[ch_id, th_id] = int(math.floor(th_per_ch)) \
                        if (scaling_factor[ch_id] < 0) ^ (ch_id in bn_nega_idx) \
                        else int(math.ceil(th_per_ch))
                else:
                    threshold_table[ch_id, th_id] = int(math.floor(th_per_ch)) \
                        if (scaling_factor < 0) ^ (ch_id in bn_nega_idx) \
                        else int(math.ceil(th_per_ch))

        # Fill the flag column: 1 when the compared thresholds are strictly
        # increasing, -1 otherwise. When the compared thresholds are all equal
        # the flag is 1 and every threshold of the channel becomes
        # max_th_value.
        for c in range(ch):
            threshold_table[c, -1] = 1 \
                if np.all(threshold_table[c, 1:-1] > threshold_table[c, :-2], axis=0) else -1
            if np.all(threshold_table[c, 1:-1] == threshold_table[c, :-2],
                      axis=0):
                threshold_table[c, -1] = 1
                threshold_table[c, 0:-1] = max_th_value

        # Put the thresholds into list
        conv_node.thresholds = threshold_table.flatten().tolist()

        # get nodes to be removed after being disconnected
        get_nodes_in_branch(activation_quantizer_node, conv_node,
                            to_be_removed)

        # Disconnect the outputs of the quantizer
        out_ops = activation_quantizer_node.output_ops['output']
        for output_node in out_ops:
            for input_name, input_node in output_node.input_ops.items():
                if input_node == activation_quantizer_node:
                    output_node.add_input(input_name, conv_node)

        # Disconnect the outputs of the conv
        conv_node.remove_output('Y')
        conv_node.add_outputs({'Y': out_ops})

    for op in to_be_removed:
        graph.remove_op(op)
    def CASE0(self, main):
        """
        Startup sequence:
        apply cell <name>
        git pull
        onos-package
        onos-verify-cell
        onos-uninstall
        onos-install
        onos-start-cli
        Set IPv6 cfg parameters for Neighbor Discovery
        start event scheduler
        start event listener

        Reads its configuration from main.params, stores the shared test
        state (element lists, locks, counters, scheduler and generator) on
        main, and reports each step through utilities.assert_equals.
        """
        import time
        from threading import Lock, Condition
        from core.graph import Graph
        from tests.CHOTestMonkey.dependencies.elements.ONOSElement import Controller
        from tests.CHOTestMonkey.dependencies.EventGenerator import EventGenerator
        from tests.CHOTestMonkey.dependencies.EventScheduler import EventScheduler

        try:
            from tests.dependencies.ONOSSetup import ONOSSetup
            main.testSetUp = ONOSSetup()
        except ImportError:
            main.log.error("ONOSSetup not found exiting the test")
            main.cleanAndExit()
        main.testSetUp.envSetupDescription()

        # Pre-bind the result so the conclusion call below cannot hit an
        # unbound name when the setup block raises before assigning it.
        stepResult = main.FALSE
        try:
            # The two lookups below are kept even though their values are not
            # used here: a missing key is reported through the except branch.
            onosPackage = main.params['TEST']['package']
            karafTimeout = main.params['TEST']['karafCliTimeout']
            main.enableIPv6 = main.params['TEST']['IPv6'] == "on"
            main.caseSleep = int(main.params['TEST']['caseSleep'])
            main.onosCell = main.params['ENV']['cellName']
            main.apps = main.params['ENV']['cellApps']
            main.controllers = []

            main.devices = []
            main.links = []
            main.hosts = []
            main.intents = []
            # Map the numeric type index of every enabled event to its name
            main.enabledEvents = {}
            for eventName in main.params['EVENT'].keys():
                if main.params['EVENT'][eventName]['status'] == 'on':
                    main.enabledEvents[int(main.params['EVENT'][eventName]
                                           ['typeIndex'])] = eventName
            # Single-argument call form: valid under both Python 2 and 3
            print(main.enabledEvents)
            main.graph = Graph()
            main.eventScheduler = EventScheduler()
            main.eventGenerator = EventGenerator()
            main.variableLock = Lock()
            main.mininetLock = Lock()
            main.ONOSbenchLock = Lock()
            main.threadID = 0
            main.eventID = 0
            main.caseResult = main.TRUE
            stepResult = main.testSetUp.envSetup()
        except Exception as e:
            main.testSetUp.envSetupException(e)

        main.testSetUp.evnSetupConclusion(stepResult)

        setupResult = main.testSetUp.ONOSSetUp(main.Cluster,
                                               cellName=main.onosCell)
        # Wrap each cluster node in a Controller element that shares its CLI
        for i in range(1, main.Cluster.numCtrls + 1):
            newController = Controller(i)
            newController.setCLI(main.Cluster.active(i - 1).CLI)
            main.controllers.append(newController)

        main.step("Set IPv6 cfg parameters for Neighbor Discovery")
        setIPv6CfgSleep = int(main.params['TEST']['setIPv6CfgSleep'])
        if main.enableIPv6:
            time.sleep(setIPv6CfgSleep)
            cfgResult1 = main.controllers[0].CLI.setCfg(
                "org.onosproject.net.neighbour.impl.NeighbourResolutionManager",
                "ndpEnabled", "true")
            time.sleep(setIPv6CfgSleep)
            cfgResult2 = main.controllers[0].CLI.setCfg(
                "org.onosproject.provider.host.impl.HostLocationProvider",
                "requestIpv6ND", "true")
        else:
            main.log.info(
                "Skipped setting IPv6 cfg parameters as it is disabled in params file"
            )
            cfgResult1 = main.TRUE
            cfgResult2 = main.TRUE
        cfgResult = cfgResult1 and cfgResult2
        utilities.assert_equals(
            expect=main.TRUE,
            actual=cfgResult,
            onpass="******",
            onfail="Failed to cfg set ipv6NeighborDiscovery")

        main.step("Start a thread for the scheduler")
        t = main.Thread(target=main.eventScheduler.startScheduler,
                        threadID=main.threadID,
                        name="startScheduler",
                        args=[])
        t.start()
        stepResult = main.TRUE
        # main.threadID is shared state, so bump it under the lock
        with main.variableLock:
            main.threadID = main.threadID + 1

        utilities.assert_equals(expect=main.TRUE,
                                actual=stepResult,
                                onpass="******",
                                onfail="Test step FAIL")

        main.step(
            "Start a thread to listen to and handle network, ONOS and application events"
        )
        t = main.Thread(target=main.eventGenerator.startListener,
                        threadID=main.threadID,
                        name="startListener",
                        args=[])
        t.start()
        with main.variableLock:
            main.threadID = main.threadID + 1

        caseResult = setupResult and cfgResult
        utilities.assert_equals(expect=main.TRUE,
                                actual=caseResult,
                                onpass="******",
                                onfail="Set up test environment FAIL")
class NetworkDriver( CLI ):

    def __init__( self ):
        """
        Create a driver with no connection and empty topology bookkeeping.

        switches: a dictionary that maps switch names to components
        hosts: a dictionary that maps host names to components
        links: starts out empty; rebuilt by updateLinks()
        """
        # Connection-related attributes start out unset
        self.name = self.home = self.handle = None
        # Each registry gets its own fresh dictionary
        self.switches, self.hosts, self.links = {}, {}, {}
        super( NetworkDriver, self ).__init__()
        self.graph = Graph()

    def checkOptions( self, var, defaultVar ):
        """
        Return var, or defaultVar when var is None or an empty string.
        """
        isUnset = var is None or var == ""
        return defaultVar if isUnset else var

    def connect( self, **connectargs ):
        """
        Creates ssh handle for the SDN network "bench".

        Every keyword argument is stored as an attribute of this driver
        before the connection is attempted.
        NOTE:
        The ip_address would come from the topo file using the host tag, the
        value can be an environment variable as well as a "localhost" to get
        the ip address needed to ssh to the "bench"

        Returns the handle on success and main.FALSE when no handle could be
        created; an EOF or any unexpected error ends the test through
        main.cleanAndExit().
        """
        try:
            vars( self ).update( connectargs )
            self.name = self.options[ 'name' ]
            try:
                # ip_address may hold the name of an environment variable
                envAddress = os.getenv( str( self.ip_address ) )
                if envAddress is None:
                    main.log.info( self.name +
                                   ": Trying to connect to " +
                                   self.ip_address )
                else:
                    self.ip_address = envAddress
            except KeyError:
                main.log.info( "Invalid host name," +
                               " connecting to local host instead" )
                self.ip_address = 'localhost'
            except Exception as inst:
                main.log.error( "Uncaught exception: " + str( inst ) )

            self.handle = super( NetworkDriver, self ).connect(
                user_name=self.user_name,
                ip_address=self.ip_address,
                port=self.port,
                pwd=self.pwd )

            if not self.handle:
                main.log.info( "Failed to create handle" )
                return main.FALSE
            main.log.info( "Connected to network bench node" )
            return self.handle
        except pexpect.EOF:
            main.log.error( self.name + ": EOF exception found" )
            main.log.error( self.name + ":     " + self.handle.before )
            main.cleanAndExit()
        except Exception:
            main.log.exception( self.name + ": Uncaught exception!" )
            main.cleanAndExit()

    def disconnect( self ):
        """
        Called when test is complete to disconnect the handle.

        Returns main.TRUE, including when there is no handle or an EOF is
        seen; any other error is logged and yields main.FALSE.
        """
        try:
            if self.handle:
                self.handle.sendline( "exit" )
                self.handle.expect( "closed" )
        except pexpect.EOF:
            # An EOF is only logged; the call still reports success
            main.log.error( self.name + ": EOF exception found" )
            main.log.error( self.name + ":     " + self.handle.before )
        except Exception:
            main.log.exception( self.name + ": Connection failed to the host" )
            return main.FALSE
        return main.TRUE

    def connectToNet( self ):
        """
        Connect to an existing physical network by getting information
        of all switch and host components created

        Every component that exists on main and whose type is a known switch
        or host driver is copied under the name "<driver name>-<shortName>",
        and the copy is registered in self.switches or self.hosts under its
        short name.

        Returns main.TRUE on success, main.FALSE if anything goes wrong.
        """
        switchTypes = ( 'MininetSwitchDriver', 'OFDPASwitchDriver' )
        hostTypes = ( 'MininetHostDriver', 'HostDriver' )
        try:
            # copyComponent() inserts new entries into
            # main.componentDictionary, so walk a snapshot instead of the
            # live view ( which raises RuntimeError on Python 3 ).
            for key, value in list( main.componentDictionary.items() ):
                if not hasattr( main, key ):
                    continue
                driverType = value[ 'type' ]
                if driverType in switchTypes:
                    registry = self.switches
                elif driverType in hostTypes:
                    registry = self.hosts
                else:
                    continue
                component = getattr( main, key )
                shortName = component.options[ 'shortName' ]
                localName = self.name + "-" + shortName
                self.copyComponent( key, localName )
                registry[ shortName ] = getattr( main, localName )
            main.log.debug( self.name + ": found switches: {}".format( self.switches ) )
            main.log.debug( self.name + ": found hosts: {}".format( self.hosts ) )
            return main.TRUE
        except Exception:
            main.log.error( self.name + ": failed to connect to network" )
            return main.FALSE

    def disconnectFromNet( self ):
        """
        Disconnect from the physical network connected

        Removes every component that exists on main and whose name starts
        with "<driver name>-", then clears self.switches and self.hosts.

        Returns main.TRUE on success, main.FALSE if anything goes wrong.
        """
        try:
            prefix = self.name + "-"
            # removeComponent() deletes entries from
            # main.componentDictionary, so iterate over a snapshot of the
            # keys ( the live view raises RuntimeError on Python 3 ).
            for key in list( main.componentDictionary ):
                if hasattr( main, key ) and key.startswith( prefix ):
                    self.removeComponent( key )
            self.switches = {}
            self.hosts = {}
            return main.TRUE
        except Exception:
            main.log.error( self.name + ": failed to disconnect from network" )
            return main.FALSE

    def copyComponent( self, name, newName ):
        """
        Register a copy of an existing component entry under a new name
        and initialize it.
        The copied components are only supposed to be called within this driver
        Required:
            name: name of the component to be copied
            newName: name of the new component
        Any error ends the test through main.cleanAndExit().
        """
        try:
            # Shallow copy of the source entry
            clonedEntry = main.componentDictionary[ name ].copy()
            main.componentDictionary[ newName ] = clonedEntry
            main.componentInit( newName )
        except Exception:
            main.log.exception( self.name + ": Uncaught exception!" )
            main.cleanAndExit()

    def removeHostComponent( self, name ):
        """
        Remove host component
        Delegates to removeComponent(); the result of that call is not
        propagated, so this method itself returns None.
        Required:
            name: name of the component to be removed
        Any exception raised by the delegate ends the test through
        main.cleanAndExit().
        """
        try:
            self.removeComponent( name )
        except Exception:
            main.log.exception( self.name + ": Uncaught exception!" )
            main.cleanAndExit()

    def removeComponent( self, name ):
        """
        Disconnect a host/switch component and forget it.
        Required:
            name: name of the component to be removed
        Returns main.TRUE once the component is removed and main.FALSE when
        no such component exists on main; any other error ends the test
        through main.cleanAndExit().
        """
        try:
            target = getattr( main, name )
        except AttributeError:
            main.log.error( "Component " + name + " does not exist." )
            return main.FALSE
        try:
            # Close the connection first, then drop the attribute on main
            # and finally the entry in the component dictionary
            target.disconnect()
            delattr( main, name )
            del main.componentDictionary[ name ]
        except Exception:
            main.log.exception( self.name + ": Uncaught exception!" )
            main.cleanAndExit()
        else:
            return main.TRUE

    def createHostComponent( self, name ):
        """
        Creates host component with the same parameters as the one copied to local.
        The parameters are taken from the entry "<driver name>-<name>" of
        main.componentDictionary.
        Arguments:
            name - The string of the name of this component. The new component
                   will be assigned to main.<name> .
                   In addition, main.<name>.name = str( name )
        If main already has an attribute called name, the test is ended
        through main.cleanAndExit().
        """
        try:
            # Probe main: an AttributeError means the name is still free
            getattr( main, name )
        except AttributeError:
            # namespace is clear, creating component
            sourceEntry = main.componentDictionary[ self.name + "-" + name ]
            main.componentDictionary[ name ] = sourceEntry.copy()
            main.componentInit( name )
        except Exception:
            main.log.exception( self.name + ": Uncaught exception!" )
            main.cleanAndExit()
        else:
            # namespace is not clear!
            main.log.error( name + " component already exists!" )
            main.cleanAndExit()

    def connectInbandHosts( self ):
        """
        Connect to hosts using data plane IPs specified

        Only hosts whose 'inband' option is the string 'True' are touched.
        Returns the combined result of the connections, or main.FALSE if an
        exception occurs.
        """
        overall = main.TRUE
        try:
            for shortName, host in self.hosts.items():
                if host.options[ 'inband' ] != 'True':
                    continue
                main.log.info( self.name + ": connecting inband host " + shortName )
                # The connection is attempted even after an earlier failure
                overall = host.connectInband() and overall
            return overall
        except Exception:
            main.log.error( self.name + ": failed to connect to inband hosts" )
            return main.FALSE

    def disconnectInbandHosts( self ):
        """
        Terminate the connections to hosts using data plane IPs

        Only hosts whose 'inband' option is the string 'True' are touched.
        Returns the combined result of the disconnections, or main.FALSE if
        an exception occurs.
        """
        overall = main.TRUE
        try:
            for shortName, host in self.hosts.items():
                if host.options[ 'inband' ] != 'True':
                    continue
                main.log.info( self.name + ": disconnecting inband host " + shortName )
                # The disconnection is attempted even after an earlier failure
                overall = host.disconnectInband() and overall
            return overall
        except Exception:
            main.log.error( self.name + ": failed to disconnect inband hosts" )
            return main.FALSE

    def getSwitches( self, timeout=60, excludeNodes=[], includeStopped=False ):
        """
        Return a dictionary which maps short names to switch data
        If includeStopped is True, stopped switches will also be included
        """
        switches = {}
        for switchName, switchComponent in self.switches.items():
            if switchName in excludeNodes:
                continue
            if not includeStopped and not switchComponent.isup:
                continue
            dpid = switchComponent.dpid.replace( '0x', '' ).zfill( 16 )
            ports = switchComponent.ports
            swClass = 'Unknown'
            pid = None
            options = None
            switches[ switchName ] = { "dpid": dpid,
                                       "ports": ports,
                                       "swClass": swClass,
                                       "pid": pid,
                                       "options": options }
        return switches

    def getHosts( self, hostClass=None ):
        """
        Return a dictionary which maps short names to host data
        """
        hosts = {}
        for hostName, hostComponent in self.hosts.items():
            interfaces = hostComponent.interfaces
            hosts[ hostName ] = { "interfaces": interfaces }
        return hosts

    def updateLinks( self, timeout=60, excludeNodes=[] ):
        """
        Update self.links by getting up-to-date port information from
        switches
        """
        # TODO: also inlcude switch-to-host links
        self.links = {}
        for node1 in self.switches.keys():
            if node1 in excludeNodes:
                continue
            self.links[ node1 ] = {}
            self.switches[ node1 ].updatePorts()
            for port in self.switches[ node1 ].ports:
                if not port[ 'enabled' ]:
                    continue
                node2 = getattr( main, port[ 'node2' ] ).shortName
                if node2 in excludeNodes:
                    continue
                port1 = port[ 'of_port' ]
                port2 = port[ 'port2' ]
                if not self.links[ node1 ].get( node2 ):
                    self.links[ node1 ][ node2 ] = {}
                # Check if this link already exists
                if self.links.get( node2 ):
                    if self.links[ node2 ].get( node1 ):
                        if self.links[ node2 ].get( node1 ).get( port2 ):
                            assert self.links[ node2 ][ node1 ][ port2 ] == port1
                            continue
                self.links[ node1 ][ node2 ][ port1 ] = port2

    def getLinks( self, timeout=60, excludeNodes=[] ):
        """
        Return a list of links specify both node names and port numbers
        """
        self.updateLinks( timeout=timeout, excludeNodes=excludeNodes )
        links = []
        for node1, nodeLinks in self.links.items():
            for node2, ports in nodeLinks.items():
                for port1, port2 in ports.items():
                    links.append( { 'node1': node1, 'node2': node2,
                                    'port1': port1, 'port2': port2 } )
        return links

    def getMacAddress( self, host ):
        """
        Return MAC address of a host
        The address is taken from the "HWaddr" field of the host's ifconfig
        output. On any failure an error is logged and None is returned.
        """
        import re
        try:
            ifconfigOutput = self.hosts[ host ].ifconfig()
            hwaddrMatch = re.search( r'HWaddr\s([0-9A-F]{2}[:-]){5}([0-9A-F]{2})',
                                     ifconfigOutput,
                                     re.I )
            # The matched text is "HWaddr <mac>"; keep the part after the space
            macAddress = hwaddrMatch.group().split( " " )[ 1 ]
            main.log.info( self.name + ": Mac-Address of Host " + host + " is " + macAddress )
            return macAddress
        except Exception:
            main.log.error( self.name + ": failed to get host MAC address" )

    def runCmdOnHost( self, hostName, cmd ):
        """
        Run shell command on specified host and return output
        Required:
            hostName: name of the host e.g. "h1"
            cmd: command to run on the host
        """
        hostComponent = self.hosts[ hostName ]
        if hostComponent:
            return hostComponent.command( cmd )
        return None

    def assignSwController( self, sw, ip, port="6653", ptcp="" ):
        """
        Description:
            Assign switches to the controllers
        Required:
            sw - Short name of the switch specified in the .topo file, e.g. "s1".
            It can also be a list of switch names.
            ip - Ip addresses of controllers. This can be a list or a string.
        Optional:
            port - ONOS use port 6653, if no list of ports is passed, then
                   the all the controller will use 6653 as their port number
            ptcp - ptcp number, This can be a string or a list that has
                   the same length as switch. This is optional and not required
                   when using ovs switches.
        NOTE: If switches and ptcp are given in a list type they should have the
              same length and should be in the same order, Eg. sw=[ 's1' ... n ]
              ptcp=[ '6637' ... n ], s1 has ptcp number 6637 and so on.

        Return:
            Returns main.TRUE if switches are correctly assigned to controllers,
            otherwise it will return main.FALSE or an appropriate exception(s)
            A switch name that is not registered in self.switches is logged
            and makes the result main.FALSE.
        """
        switchList = []
        ptcpList = None
        try:
            # str / list are the objects types.StringType / types.ListType
            # named on Python 2, and they also exist on Python 3.
            if isinstance( sw, str ):
                switchList.append( sw )
                if ptcp:
                    if isinstance( ptcp, str ):
                        ptcpList = [ ptcp ]
                    elif isinstance( ptcp, list ):
                        main.log.error( self.name + ": Only one switch is " +
                                        "being set and multiple PTCP is " +
                                        "being passed " )
                        return main.FALSE
                    else:
                        main.log.error( self.name + ": Invalid PTCP" )
                        return main.FALSE

            elif isinstance( sw, list ):
                switchList = sw
                if ptcp:
                    if isinstance( ptcp, list ):
                        if len( ptcp ) != len( sw ):
                            main.log.error( self.name + ": PTCP length = " +
                                            str( len( ptcp ) ) +
                                            " is not the same as switch" +
                                            " length = " +
                                            str( len( sw ) ) )
                            return main.FALSE
                        else:
                            ptcpList = ptcp
                    else:
                        main.log.error( self.name + ": Invalid PTCP" )
                        return main.FALSE
            else:
                main.log.error( self.name + ": Invalid switch type " )
                return main.FALSE

            assignResult = main.TRUE
            for index, switch in enumerate( switchList ):
                # .get() lets an unknown switch reach the error branch below
                # instead of raising KeyError and aborting the whole test.
                switchComponent = self.switches.get( switch )
                if not switchComponent:
                    main.log.error( self.name + ": Not able to find switch " + switch )
                    assignResult = main.FALSE
                    continue
                switchPtcp = ptcpList[ index ] if ptcpList else ""
                # As before, `and` short-circuits: once a result is false the
                # remaining switches are not assigned.
                assignResult = assignResult and switchComponent.assignSwController( ip=ip, port=port, ptcp=switchPtcp )
            return assignResult

        except Exception:
            main.log.exception( self.name + ": Uncaught exception!" )
            main.cleanAndExit()

    def pingall( self, protocol="IPv4", timeout=300, shortCircuit=False, acceptableFailed=0 ):
        """
        Description:
            Verifies the reachability of the hosts using ping command.
            Every ordered pair of hosts in self.hosts is pinged.
        Optional:
            protocol - use ping6 command if specified as "IPv6"
            timeout( seconds ) - How long to wait before breaking the pingall
            shortCircuit - Break the pingall based on the number of failed hosts ping
            acceptableFailed - Set the number of acceptable failed pings for the
                               function to still return main.TRUE
        Returns:
            main.TRUE if pingall completes with no more than acceptableFailed
            pings dropped ( none by default )
            otherwise main.FALSE
        """
        import time
        import itertools
        try:
            timeout = int( timeout )
            main.log.info( self.name + ": Checking reachabilty to the hosts using ping" )
            failedPings = 0
            returnValue = main.TRUE
            ipv6 = protocol == "IPv6"
            addressKey = 'ip6' if ipv6 else 'ip'
            startTime = time.time()
            hostPairs = itertools.permutations( list( self.hosts.values() ), 2 )
            for srcHost, dstHost in hostPairs:
                pingResult = srcHost.ping( dstHost.options[ addressKey ], ipv6=ipv6 )
                if ( time.time() - startTime ) > timeout:
                    returnValue = main.FALSE
                    main.log.error( self.name +
                                    ": Aborting pingall - " +
                                    "Function took too long " )
                    break
                if not pingResult:
                    failedPings = failedPings + 1
                    # Only fail once the tolerated number of dropped pings is
                    # exceeded; and-ing every ping result into returnValue
                    # made acceptableFailed ineffective.
                    if failedPings > acceptableFailed:
                        returnValue = main.FALSE
                        if shortCircuit:
                            main.log.error( self.name +
                                            ": Aborting pingall - "
                                            + str( failedPings ) +
                                            " pings failed" )
                            break
            return returnValue
        except Exception:
            main.log.exception( self.name + ": Uncaught exception!" )
            main.cleanAndExit()

    def pingallHosts( self, hostList, wait=1 ):
        """
            Ping all specified IPv4 hosts

            Acceptable hostList:
                - [ 'h1','h2','h3','h4' ]

            wait is converted to int and passed on to each ping.

            Returns main.TRUE if all hosts specified can reach
            each other

            Returns main.FALSE if one or more of hosts specified
            cannot reach each other"""
        import itertools
        hostComponentList = []
        for hostName in hostList:
            hostComponent = self.hosts[ hostName ]
            if hostComponent:
                hostComponentList.append( hostComponent )
        try:
            main.log.info( "Testing reachability between specified hosts" )
            isReachable = main.TRUE
            pingResponse = "IPv4 ping across specified hosts\n"
            failedPings = 0
            for srcHost, dstHost in itertools.permutations( hostComponentList, 2 ):
                pingResponse += srcHost.options[ 'shortName' ] + " -> "
                # This method is IPv4-only; it used to consult an undefined
                # `ipv6` name here, which raised NameError on the first pair.
                ipDst = dstHost.options[ 'ip' ]
                pingResult = srcHost.ping( ipDst, wait=int( wait ) )
                if pingResult:
                    pingResponse += dstHost.options[ 'shortName' ]
                else:
                    pingResponse += "X"
                    # One of the host to host pair is unreachable
                    isReachable = main.FALSE
                    failedPings += 1
                pingResponse += "\n"
            main.log.info( pingResponse + "Failed pings: " + str( failedPings ) )
            return isReachable
        except Exception:
            main.log.exception( self.name + ": Uncaught exception!" )
            main.cleanAndExit()

    def iperftcp( self, host1="h1", host2="h2", timeout=6 ):
        '''
        Creates an iperf TCP test between two hosts. Returns main.TRUE if test results
        are valid.
        NOTE: not implemented yet - the body only logs a message and returns
        main.TRUE; host1, host2 and timeout are not used.
        Optional:
            timeout: The default timeout is 6 sec to allow enough time for a successful test to complete,
            and short enough to stop an unsuccessful test from quitting and cleaning up mininet.
        '''
        main.log.info( self.name + ": Simple iperf TCP test between two hosts" )
        # TODO: complete this function
        return main.TRUE

    def update( self ):
        """
        Does nothing and always returns main.TRUE.
        """
        return main.TRUE

    def verifyHostIp( self, hostList=[], prefix="", update=False ):
        """
        Description:
            Verify that all hosts have IP address assigned to them
        Optional:
            hostList: If specified, verifications only happen to the hosts
            in hostList. The list passed in is not modified.
            prefix: at least one of the ip address assigned to the host
            needs to have the specified prefix
            update: currently unused
        Returns:
            main.TRUE if all hosts have specific IP address assigned;
            main.FALSE otherwise (including when hostList names a host
            that is not in self.hosts);
            None if a KeyError or another unexpected exception is raised
        """
        ipv4Pattern = r'inet ((?:[0-9]{1,3}\.){3}[0-9]{1,3})/'
        # It's tricky to make regex for IPv6 addresses and this one is simplified
        ipv6Pattern = r'inet6 ((?:[0-9a-fA-F]{1,4})?(?:[:0-9a-fA-F]{1,4}){1,7}(?:::)?(?:[:0-9a-fA-F]{1,4}){1,7})/'
        try:
            # Work on a private list: the caller's hostList must not be
            # emptied as a side effect, and dict.keys() has no remove()
            # on Python 3.
            if hostList:
                pendingHosts = list( hostList )
            else:
                pendingHosts = list( self.hosts.keys() )
            for hostName, hostComponent in self.hosts.items():
                if hostName not in pendingHosts:
                    continue
                ipa = hostComponent.ip()
                ipList = re.findall( ipv4Pattern, ipa ) + re.findall( ipv6Pattern, ipa )
                main.log.debug( self.name + ": IP list on host " + str( hostName ) + ": " + str( ipList ) )
                if not ipList:
                    main.log.warn( self.name + ": Failed to discover any IP addresses on host " + str( hostName ) )
                elif not any( ip.startswith( str( prefix ) ) for ip in ipList ):
                    main.log.warn( self.name + ": None of the IPs on host " + str( hostName ) + " has prefix " + str( prefix ) )
                else:
                    main.log.debug( self.name + ": Found matching IP on host " + str( hostName ) )
                    # Verified hosts are crossed off; whatever remains failed
                    pendingHosts.remove( hostName )
            return main.FALSE if pendingHosts else main.TRUE
        except KeyError:
            # The key list has to be converted explicitly: str + list raises TypeError
            main.log.exception( self.name + ": host data not as expected: " + str( list( self.hosts.keys() ) ) )
            return None
        except pexpect.EOF:
            main.log.error( self.name + ": EOF exception found" )
            main.log.error( self.name + ":     " + self.handle.before )
            main.cleanAndExit()
        except Exception:
            main.log.exception( self.name + ": Uncaught exception" )
            return None

    def addRoute( self, host, dstIP, interface, ipv6=False ):
        """
        Add a route to host by running the "route" command on it with sudo
        Ex: h1 route add -host 224.2.0.1 h1-eth0
        Required:
            host: the host to run the command on, handed to runCmdOnHost
            dstIP: destination address of the route
            interface: interface the route is attached to
        Optional:
            ipv6: if True "route -A inet6 add" is used, otherwise
                  "route add -host"
        Returns:
            main.TRUE once the command has been run; the command output is
            only logged, not checked.
            main.FALSE if pexpect.EOF is raised.
            On pexpect.TIMEOUT or any other exception main.cleanAndExit()
            is called.
        """
        try:
            routePrefix = "sudo route -A inet6 add " if ipv6 else "sudo route add -host "
            routeCmd = routePrefix + str( dstIP ) + " " + str( interface )
            output = self.runCmdOnHost( host, routeCmd )
            main.log.debug( "response = " + output )
            return main.TRUE
        except pexpect.TIMEOUT:
            main.log.error( self.name + ": TIMEOUT exception found" )
            main.log.error( self.name + ":     " + self.handle.before )
            main.cleanAndExit()
        except pexpect.EOF:
            main.log.error( self.name + ": EOF exception found" )
            main.log.error( self.name + ":     " + self.handle.before )
            return main.FALSE
        except Exception:
            main.log.exception( self.name + ": Uncaught exception!" )
            main.cleanAndExit()

    def getIPAddress( self, host, proto='IPV4' ):
        """
        Returns IP address of the host, parsed from the output of
        "ifconfig" run on that host
        Required:
            host: name of the host; it is handed to runCmdOnHost and
                  concatenated into the log message
        Optional:
            proto: 'IPV4' to look for an "inet ... netmask" entry; any
                   other value looks for an "inet6 .../<len> prefixlen"
                   entry
        Returns:
            The first matching address as a string, or None if the
            ifconfig output contains no match
        """
        response = self.runCmdOnHost( host, "ifconfig" )
        # Raw strings keep \s and \d as regex escapes instead of relying on
        # Python passing unknown string escapes through unchanged
        if proto == 'IPV4':
            pattern = r"inet\s(\d+\.\d+\.\d+\.\d+)\s\snetmask"
        else:
            pattern = r"inet6\s([\w,:]*)/\d+\s\sprefixlen"
        ipAddressSearch = re.search( pattern, response )
        if not ipAddressSearch:
            return None
        ipAddress = ipAddressSearch.group( 1 )
        main.log.info( self.name + ": IP-Address of Host " + host + " is " + ipAddress )
        return ipAddress

    def getLinkRandom( self, timeout=60, nonCut=True, excludeNodes=[], skipLinks=[] ):
        """
        Randomly get a link from network topology.
        If nonCut is True, it gets a list of non-cut links (the deletion
        of a non-cut link will not increase the number of connected
        component of a graph) and randomly returns one of them, otherwise
        it just randomly returns one link from all current links, leaving
        out links where either node name starts with 'h'.
        excludeNodes will be passed to getLinks and getGraphDict method.
        skipLinks holds node names: any link that has either end included
        in skipLinks will be excluded.
        Returns the link as a list, e.g. [ 's1', 's2' ], or None if there
        is no candidate link or the topology could not be read.
        """
        import random
        try:
            if not nonCut:
                links = self.getLinks( timeout=timeout, excludeNodes=excludeNodes )
                assert len( links ) != 0
                # Exclude host-switch link
                candidateLinks = [ [ link[ 'node1' ], link[ 'node2' ] ]
                                   for link in links
                                   if not ( link[ 'node1' ].startswith( 'h' ) or
                                            link[ 'node2' ].startswith( 'h' ) ) ]
            else:
                graphDict = self.getGraphDict( timeout=timeout, useId=False,
                                               excludeNodes=excludeNodes )
                if graphDict is None:
                    return None
                self.graph.update( graphDict )
                candidateLinks = self.graph.getNonCutEdges()
            # This check has to come before the filter below: iterating over
            # None would raise TypeError and be reported as an uncaught
            # exception instead of a plain "no result"
            if candidateLinks is None:
                return None
            candidateLinks = [ link for link in candidateLinks
                               if link[ 0 ] not in skipLinks and link[ 1 ] not in skipLinks ]
            if not candidateLinks:
                main.log.info( self.name + ": No candidate link for deletion" )
                return None
            return random.choice( candidateLinks )
        except KeyError:
            main.log.exception( self.name + ": KeyError exception found" )
            return None
        except AssertionError:
            main.log.exception( self.name + ": AssertionError exception found" )
            return None
        except Exception:
            main.log.exception( self.name + ": Uncaught exception" )
            return None

    def getSwitchRandom( self, timeout=60, nonCut=True, excludeNodes=[], skipSwitches=[] ):
        """
        Randomly get a switch from network topology.
        If nonCut is True, it gets a list of non-cut switches (the deletion
        of a non-cut switch will not increase the number of connected
        components of a graph) and randomly returns one of them, otherwise
        it just randomly returns one switch from all current switches
        reported by getSwitches.
        excludeNodes will be passed to getSwitches and getGraphDict method.
        Switches specified in skipSwitches will be excluded.
        Returns the name of the chosen switch, or None if there is no
        candidate switch or the topology could not be read.
        """
        import random
        try:
            if not nonCut:
                switches = self.getSwitches( timeout=timeout, excludeNodes=excludeNodes )
                assert len( switches ) != 0
                candidateSwitches = list( switches.keys() )
            else:
                graphDict = self.getGraphDict( timeout=timeout, useId=False,
                                               excludeNodes=excludeNodes )
                if graphDict is None:
                    return None
                self.graph.update( graphDict )
                candidateSwitches = self.graph.getNonCutVertices()
            # This check has to come before the filter below: iterating over
            # None would raise TypeError and be reported as an uncaught
            # exception instead of a plain "no result"
            if candidateSwitches is None:
                return None
            candidateSwitches = [ switch for switch in candidateSwitches
                                  if switch not in skipSwitches ]
            if not candidateSwitches:
                main.log.info( self.name + ": No candidate switch for deletion" )
                return None
            return random.choice( candidateSwitches )
        except KeyError:
            main.log.exception( self.name + ": KeyError exception found" )
            return None
        except AssertionError:
            main.log.exception( self.name + ": AssertionError exception found" )
            return None
        except Exception:
            main.log.exception( self.name + ": Uncaught exception" )
            return None

    def getGraphDict( self, timeout=60, useId=True, includeHost=False,
                      excludeNodes=[] ):
        """
        Return a dictionary which describes the latest network topology data as a
        graph.
        An example of the dictionary:
        { vertex1: { 'edges': ..., 'name': ..., 'protocol': ... },
          vertex2: { 'edges': ..., 'name': ..., 'protocol': ... } }
        Each vertex should at least have an 'edges' attribute which describes the
        adjacency information. The value of 'edges' attribute is also represented by
        a dictionary, which maps each edge (identified by the neighbor vertex) to a
        list of attributes.
        An example of the edges dictionary:
        'edges': { vertex2: { 'port': ..., 'weight': ... },
                   vertex3: { 'port': ..., 'weight': ... } }
        If useId == True, dpid/mac will be used instead of names to identify
        vertices, which is helpful when e.g. comparing network topology with ONOS
        topology.
        If includeHost == True, getHosts is queried, but hosts (and host-switch
        links) are not added to the topology data yet.
        timeout and excludeNodes will be passed to getSwitches and getLinks
        methods; excludeNodes excludes unexpected switches and links.
        Links with a node name starting with 'h', links touching a switch that
        is not up, and links with a disabled port on either end are left out.
        Returns the dictionary, or None if the topology data could not be
        processed.
        """
        # TODO: support excludeNodes
        graphDict = {}
        try:
            links = self.getLinks( timeout=timeout, excludeNodes=excludeNodes )
            switches = self.getSwitches( timeout=timeout, excludeNodes=excludeNodes )
            if includeHost:
                # TODO: support 'includeHost' argument; the result is not used yet
                self.getHosts()
            for link in links:
                if link[ 'node1' ].startswith( 'h' ) or link[ 'node2' ].startswith( 'h' ):
                    continue
                nodeName1 = link[ 'node1' ]
                nodeName2 = link[ 'node2' ]
                if not self.switches[ nodeName1 ].isup or not self.switches[ nodeName2 ].isup:
                    continue
                # Every link is recorded once per direction, each time from
                # the point of view of the local end and its own port
                directions = ( ( nodeName1, nodeName2, link[ 'port1' ] ),
                               ( nodeName2, nodeName1, link[ 'port2' ] ) )
                for localName, remoteName, portIndex in directions:
                    if useId:
                        local = 'of:' + str( switches[ localName ][ 'dpid' ] )
                        remote = 'of:' + str( switches[ remoteName ][ 'dpid' ] )
                    else:
                        local = localName
                        remote = remoteName
                    if local not in graphDict:
                        if useId:
                            localData = switches[ localName ]
                            graphDict[ local ] = { 'edges': {},
                                                   'dpid': localData[ 'dpid' ],
                                                   'name': localName,
                                                   'ports': localData[ 'ports' ],
                                                   'swClass': localData[ 'swClass' ],
                                                   'pid': localData[ 'pid' ],
                                                   'options': localData[ 'options' ] }
                        else:
                            graphDict[ local ] = { 'edges': {} }
                    for port in switches[ localName ][ 'ports' ]:
                        if port[ 'of_port' ] == str( portIndex ):
                            # Use -1 as index for disabled port
                            graphDict[ local ][ 'edges' ][ remote ] = {
                                'port': portIndex if port[ 'enabled' ] else -1 }
            # Remove links with disabled ports, in both directions
            linksToRemove = [ ( node, neighbor )
                              for node, attributes in graphDict.items()
                              for neighbor, edge in attributes[ 'edges' ].items()
                              if edge[ 'port' ] == -1 ]
            for node1, node2 in linksToRemove:
                for local, remote in ( ( node1, node2 ), ( node2, node1 ) ):
                    # The reverse edge may be missing or already removed
                    graphDict[ local ][ 'edges' ].pop( remote, None )
            return graphDict
        except KeyError:
            main.log.exception( self.name + ": KeyError exception found" )
            return None
        except AssertionError:
            main.log.exception( self.name + ": AssertionError exception found" )
            return None
        except pexpect.EOF:
            main.log.error( self.name + ": EOF exception found" )
            main.log.error( self.name + ":     " + self.handle.before )
            main.cleanAndExit()
        except Exception:
            main.log.exception( self.name + ": Uncaught exception" )
            return None

    def switch( self, **switchargs ):
        """
        start/stop a switch
        Keyword arguments are parsed with utilities.parse_args, which is
        asked for the entries "SW" and "OPTION":
            SW: name of the switch, looked up in self.switches
            OPTION: 'stop' calls stopOfAgent() on the switch component,
                    'start' calls startOfAgent()
        Returns:
            main.TRUE if the switch was started/stopped;
            main.FALSE if the option is neither 'start' nor 'stop';
            None if the switch is unknown or on pexpect.TIMEOUT.
            On pexpect.EOF or any other exception main.cleanAndExit() is
            called.
        """
        args = utilities.parse_args( [ "SW", "OPTION" ], **switchargs )
        sw = args[ "SW" ] if args[ "SW" ] is not None else ""
        option = args[ "OPTION" ] if args[ "OPTION" ] is not None else ""
        try:
            switchComponent = self.switches[ sw ]
            if option == 'stop':
                switchComponent.stopOfAgent()
            elif option == 'start':
                switchComponent.startOfAgent()
            else:
                main.log.warn( self.name + ": Unknown switch command" )
                return main.FALSE
            return main.TRUE
        except KeyError:
            # The placeholder must be "{}": an unbalanced "}" makes format()
            # raise ValueError from inside this handler
            main.log.error( self.name + ": Not able to find switch {}".format( sw ) )
            return None
        except pexpect.TIMEOUT:
            main.log.error( self.name + ": TIMEOUT exception found" )
            main.log.error( self.name + ":     " + self.handle.before )
            return None
        except pexpect.EOF:
            main.log.error( self.name + ": EOF exception found" )
            main.log.error( self.name + ":     " + self.handle.before )
            main.cleanAndExit()
        except Exception:
            main.log.exception( self.name + ": Uncaught exception" )
            main.cleanAndExit()

    def discoverHosts( self, hostList=[], wait=1000, dstIp="6.6.6.6", dstIp6="1020::3fe" ):
        '''
        Make each host in hostList send a single ARP/ND request to a
        non-existent address so that ONOS can discover it. A host with an
        IPv4 address uses arping; otherwise, a host with an IPv6 address uses
        ndisc6 on its first interface. The neighbor table of the host is
        flushed before the request is sent.
        Optional:
            hostList: a list of names of the hosts that need to be discovered. If
                      not specified, all hosts returned by getHosts are used
            wait: timeout for ARP/ND in milliseconds
            dstIp: destination address used by IPv4 hosts
            dstIp6: destination address used by IPv6 hosts
        Returns:
            main.FALSE if at least one host has no IP address configured and
            was skipped, otherwise main.TRUE. The output of the commands
            themselves is not checked.
            On any exception main.cleanAndExit() is called.
        '''
        try:
            hosts = self.getHosts()
            targets = hostList if hostList else hosts.keys()
            discoveryResult = main.TRUE
            for host in targets:
                if self.getIPAddress( host ):
                    flushCmd = "sudo ip neigh flush all"
                    cmd = "arping -c 1 -w {} {}".format( wait, dstIp )
                    main.log.debug( "Sending IPv4 arping from host {}".format( host ) )
                elif self.getIPAddress( host, proto='IPV6' ):
                    flushCmd = "sudo ip -6 neigh flush all"
                    intf = hosts[host]['interfaces'][0]['name']
                    cmd = "ndisc6 -r 1 -w {} {} {}".format( wait, dstIp6, intf )
                    main.log.debug( "Sending IPv6 ND from host {}".format( host ) )
                else:
                    # Nothing can be sent from a host without an address
                    main.log.warn( "No IP addresses configured on host {}, skipping discovery".format( host ) )
                    discoveryResult = main.FALSE
                    continue
                self.runCmdOnHost( host, flushCmd )
                self.runCmdOnHost( host, cmd )
            return discoveryResult
        except Exception:
            main.log.exception( self.name + ": Uncaught exception!" )
            main.cleanAndExit()