Beispiel #1
0
def check_device_assign_pass(default_device,
                             default_device_id,
                             graph_op_metadata,
                             graph_op=None,
                             *args):
    """
    Check that DeviceAssignPass annotates each op with the expected metadata.

    The Device assign pass should inject the metadata {device_id, device} as
    specified by the user for each op; if not specified, the default
    {device_id: 0, device: cpu} should be inserted for each op. The pass is
    run against a mock HeTr object, and the set of transformer names it
    registers is compared with the names implied by ``graph_op_metadata``.

    :param: default_device: string, the default device for each op,
            if not specified by user ex: "cpu"
    :param: default_device_id: string, the default device number for each op,
            if not specified by user ex: "0"
    :param: graph_op_metadata: dict mapping each op to a sequence whose
            item 0 is the expected device and item 1 the expected device_id
            (a single id, or a list/tuple of ids)
    :param: graph_op: ops to do the graph traversal over; defaults to a
            fresh empty ``OrderedSet``
    :param: args: extra positional arguments; accepted and ignored
    """
    # Create the default per call so no OrderedSet instance is shared
    # between invocations (a default expression is evaluated only once).
    if graph_op is None:
        graph_op = OrderedSet()

    with ExecutorFactory():
        expected_transformers = set()

        class MockHetr(object):
            """Minimal HeTr stand-in that records registered transformers."""

            def __init__(self):
                self.transformers = set()

            def register_transformer(self, transformer):
                self.transformers.add(transformer)

        hetr = MockHetr()
        obj = DeviceAssignPass(hetr, default_device, default_device_id)

        obj.do_pass(ops=graph_op)

        for op, expected in graph_op_metadata.items():
            device, device_id = expected[0], expected[1]
            assert op.metadata['device'] == device
            assert op.metadata['device_id'] == device_id

            # Normalise to a tuple of ids so single ids and id lists share
            # one code path. Indexing or iterating a bare string id would
            # split multi-character ids such as "10" into "1" and "0".
            if isinstance(device_id, (list, tuple)):
                device_ids = tuple(device_id)
                expected_transformer = [device + str(i) for i in device_ids]
            else:
                device_ids = (device_id,)
                expected_transformer = device + str(device_id)
            assert op.metadata['transformer'] == expected_transformer

            # str() keeps this consistent with the names built above even
            # when ids are given as ints.
            expected_transformers.update(device + str(i) for i in device_ids)
        assert hetr.transformers == expected_transformers
def check_device_assign_pass(default_device,
                             default_device_id,
                             graph_op_metadata,
                             graph_op=None,
                             *args):
    """
    Check that DeviceAssignPass annotates each op with the expected metadata.

    The Device assign pass should inject the metadata {device_id, device} as
    specified by the user for each op; if not specified, the default
    {device_id: 0, device: numpy} should be inserted for each op. The set
    passed to the pass as ``transformers`` is compared with the transformer
    names recorded on the ops.

    :param: default_device: string, the default device for each op,
            if not specified by user ex: "numpy"
    :param: default_device_id: string, the default device number for each op,
            if not specified by user ex: "0"
    :param: graph_op_metadata: dict mapping each op to a sequence whose
            item 0 is the expected device and item 1 the expected device_id
    :param: graph_op: ops to do the graph traversal over; defaults to a
            fresh empty ``OrderedSet``
    :param: args: extra positional arguments; accepted and ignored
    """
    # Create the default per call so no OrderedSet instance is shared
    # between invocations (a default expression is evaluated only once).
    if graph_op is None:
        graph_op = OrderedSet()

    transformer = ngt.make_transformer_factory('hetr')()
    try:
        # Set the pass is given to fill in; checked against the ops below.
        transformers = set()
        expected_transformers = set()
        obj = DeviceAssignPass(default_device, default_device_id, transformers)

        obj.do_pass(graph_op, transformer)

        for op, expected in graph_op_metadata.items():
            device, device_id = expected[0], expected[1]
            assert op.metadata['device'] == device
            assert op.metadata['device_id'] == device_id
            assert op.metadata['transformer'] == device + str(device_id)

            expected_transformers.add(op.metadata['transformer'])
        assert transformers == expected_transformers
    finally:
        # Close even when an assertion fails, so a failing check does not
        # leak the hetr transformer into subsequent tests.
        transformer.close()