예제 #1
0
def check_device_assign_pass(default_device,
                             default_device_id,
                             graph_op_metadata,
                             graph_op=None,
                             *args):
    """
    Check that DeviceAssignPass injects the expected device metadata.

    The Device assign pass should inject the metadata {device_id, device} as
    specified by the user for each op; if not specified then the default
    {device_id: 0, device: cpu} should be inserted for each op. It is also
    expected to register every resulting transformer name with the hetr
    object it was given.

    :param: default_device: string, the default device for each op,
            if not specified by user ex: "cpu"
    :param: default_device_id: string, the default device number for each op,
            if not specified by user ex: "0"
    :param: graph_op_metadata: dict mapping each op to a list whose first two
            items are the expected device and device_id; device_id may be a
            single id or a list/tuple of ids
    :param: graph_op: ops to do the graph traversal over; a fresh empty
            OrderedSet is used when omitted
    :param: args: unused, accepted for call-site compatibility
    """
    # A default of OrderedSet() would be created once and shared by every
    # call; build a fresh one per call instead.
    if graph_op is None:
        graph_op = OrderedSet()

    with ExecutorFactory():
        expected_transformers = set()

        class MockHetr(object):
            """Stand-in for the hetr transformer that records registrations."""

            def __init__(self):
                self.transformers = set()

            def register_transformer(self, transformer):
                self.transformers.add(transformer)

        hetr = MockHetr()
        obj = DeviceAssignPass(hetr, default_device, default_device_id)

        obj.do_pass(ops=graph_op)

        for op, expected in graph_op_metadata.items():
            device, device_id = expected[0], expected[1]
            assert op.metadata['device'] == device
            assert op.metadata['device_id'] == device_id

            # Multi-device ops get one "<device><id>" name per id; single-device
            # ops get one name built from the whole id (not just its first
            # character, which would break ids like '10' or int ids).
            if isinstance(device_id, (list, tuple)):
                device_ids = list(device_id)
                transformer = [device + str(i) for i in device_ids]
            else:
                device_ids = [device_id]
                transformer = device + str(device_id)
            assert op.metadata['transformer'] == transformer

            # Iterate ids, never the characters of a scalar string id.
            expected_transformers.update(device + str(i) for i in device_ids)
        assert hetr.transformers == expected_transformers
예제 #2
0
    def __init__(self, **kwargs):
        """
        Initialise the Hetr transformer state and its list of hetr passes.

        Keyword arguments are forwarded to the base transformer. Construction
        is rejected with an AssertionError when the class-level
        ``hetr_counter`` would exceed 1 (or drop below 0).
        NOTE(review): whether/where the counter is decremented (e.g. on close)
        is not visible here — confirm.
        """
        super(HetrTransformer, self).__init__(**kwargs)

        self.my_pid = os.getpid()
        self.is_closed = False
        self.child_transformers = dict()
        self.transformer_list = list()
        # Filled by DeviceAssignPass with the transformer names it assigns.
        self.transformers = set()
        # Shared between CommunicationPass and DistributedPass.
        self.send_nodes = OrderedSet()
        self.scatter_shared_queues = list()
        self.gather_shared_queues = list()
        self.hetr_passes = [
            DeviceAssignPass(default_device='numpy',
                             default_device_id=0,
                             transformers=self.transformers),
            CommunicationPass(self.send_nodes, self.scatter_shared_queues,
                              self.gather_shared_queues),
            DistributedPass(self.send_nodes, self.scatter_shared_queues,
                            self.gather_shared_queues),
            ChildTransformerPass(self.transformer_list)
        ]
        self.vizpass = None

        self.inits = OrderedSet()

        # Validate the would-be count *before* committing it, so a rejected
        # construction does not leave hetr_counter permanently incremented
        # (which would make every later construction fail too).
        new_count = HetrTransformer.hetr_counter + 1
        assert new_count <= 1
        assert new_count >= 0
        HetrTransformer.hetr_counter = new_count
예제 #3
0
    def __init__(self, device='cpu', **kwargs):
        """
        Initialise the Hetr transformer and its graph passes, and optionally
        launch remote hetr servers through mpirun.

        :param device: default device assigned (with device id 0) to ops that
            do not specify one.

        mpirun is invoked only when both HETR_SERVER_NUM and
        HETR_SERVER_HOSTFILE are set in the environment. ``self.use_mlsl`` is
        True exactly in that case and False otherwise. The return code of
        mpirun is not checked.
        """
        super(HetrTransformer, self).__init__(**kwargs)

        self.my_pid = os.getpid()
        self.is_closed = False
        self.child_transformers = dict()
        self.send_nodes = OrderedSet()
        self.graph_passes = [
            DeviceAssignPass(hetr=self,
                             default_device=device,
                             default_device_id=0),
            CommunicationPass(self.send_nodes),
            DistributedPass(self.send_nodes)
        ]

        hetr_server_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "cpu", "hetr_server.py")
        hetr_server_num = os.getenv('HETR_SERVER_NUM')
        hetr_server_hostfile = os.getenv('HETR_SERVER_HOSTFILE')
        # Assumption is that hydra_persist processes are started on remote nodes
        # Otherwise, remove "-bootstrap persist" from the command line (it then uses ssh)
        if hetr_server_num is not None and hetr_server_hostfile is not None:
            # Pass an argument list without a shell: paths or environment
            # values containing spaces or shell metacharacters are then
            # passed through verbatim instead of being re-split/interpreted.
            # Note: a missing mpirun binary now raises OSError rather than
            # being silently ignored by the shell.
            mpirun_cmd = [
                "mpirun", "-n", hetr_server_num, "-ppn", "1",
                "-bootstrap", "persist",
                "-hostfile", hetr_server_hostfile,
                hetr_server_path,
            ]
            subprocess.call(mpirun_cmd)
            self.use_mlsl = True
        else:
            self.use_mlsl = False
예제 #4
0
def check_device_assign_pass(default_device,
                             default_device_id,
                             graph_op_metadata,
                             graph_op=None,
                             *args):
    """
    Check that DeviceAssignPass injects the expected device metadata.

    The Device assign pass should inject the metadata {device_id, device} as
    specified by the user for each op; if not specified then the default
    {device_id: 0, device: numpy} should be inserted for each op. Every
    transformer name it assigns is expected to be collected in the
    ``transformers`` set handed to the pass.

    :param: default_device: string, the default device for each op,
            if not specified by user ex: "numpy"
    :param: default_device_id: string, the default device number for each op,
            if not specified by user ex: "0"
    :param: graph_op_metadata: dict mapping each op to a list whose first two
            items are the expected device and device_id
    :param: graph_op: ops to do the graph traversal over; a fresh empty
            OrderedSet is used when omitted
    :param: args: unused, accepted for call-site compatibility
    """
    # A default of OrderedSet() would be created once and shared by every
    # call; build a fresh one per call instead.
    if graph_op is None:
        graph_op = OrderedSet()

    transformer = ngt.make_transformer_factory('hetr')()
    try:
        transformers = set()
        expected_transformers = set()
        obj = DeviceAssignPass(default_device, default_device_id, transformers)

        obj.do_pass(graph_op, transformer)

        for op, expected in graph_op_metadata.items():
            device, device_id = expected[0], expected[1]
            assert op.metadata['device'] == device
            assert op.metadata['device_id'] == device_id
            assert op.metadata['transformer'] == device + str(device_id)

            expected_transformers.add(op.metadata['transformer'])
        assert transformers == expected_transformers
    finally:
        # Release the hetr transformer even when an assertion above fails,
        # so a failing check does not leak it into subsequent tests.
        transformer.close()
예제 #5
0
    def __init__(self, **kwargs):
        """
        Create a Hetr transformer whose ops default to device 'cpu', id 0.

        Keyword arguments are forwarded unchanged to the base transformer.
        """
        super(HetrTransformer, self).__init__(**kwargs)

        self.my_pid = os.getpid()
        self.is_closed = False
        self.child_transformers = {}

        # The same OrderedSet instance is stored on the transformer and handed
        # to both the communication and the distributed pass.
        nodes_to_send = OrderedSet()
        self.send_nodes = nodes_to_send

        device_pass = DeviceAssignPass(hetr=self,
                                       default_device='cpu',
                                       default_device_id=0)
        communication_pass = CommunicationPass(nodes_to_send)
        distributed_pass = DistributedPass(nodes_to_send)
        self.graph_passes = [device_pass, communication_pass, distributed_pass]
예제 #6
0
    def __init__(self, device='cpu', **kwargs):
        """
        Create a Hetr transformer and its MPI launcher.

        :param device: device name assigned (with device id 0) to ops that do
            not specify one; also stored as ``self.default_device``.

        Remaining keyword arguments are forwarded to the base transformer.
        """
        super(HetrTransformer, self).__init__(**kwargs)

        self.default_device = device
        self.my_pid = os.getpid()
        self.is_closed = False
        self.child_transformers = {}

        # Shared between the transformer and the communication pass.
        pending_sends = OrderedSet()
        self.send_nodes = pending_sends

        passes = [DeviceAssignPass(hetr=self,
                                   default_device=device,
                                   default_device_id=0)]
        passes.append(CommunicationPass(pending_sends))
        passes.append(AxesUpdatePass())
        self.graph_passes = passes

        self.mpilauncher = MPILauncher()
예제 #7
0
    def __init__(self, device='cpu', **kwargs):
        """
        Create a Hetr transformer, reserve RPC ports and start the launcher.

        :param device: device name assigned (with device id 0) to ops that do
            not specify one.

        Remaining keyword arguments are forwarded to the base transformer.
        The launcher is constructed with the reserved ports and ``launch()``
        is called on it before this constructor returns.
        """
        super(HetrTransformer, self).__init__(**kwargs)

        self.my_pid = os.getpid()
        self.is_closed = False
        self.child_transformers = {}

        # One OrderedSet shared by the communication and distributed passes.
        shared_send_nodes = OrderedSet()
        self.send_nodes = shared_send_nodes
        self.graph_passes = [
            DeviceAssignPass(hetr=self, default_device=device,
                             default_device_id=0),
            CommunicationPass(shared_send_nodes),
            DistributedPass(shared_send_nodes),
        ]

        ports = get_available_ports()
        self.rpc_ports = ports
        # Index into rpc_ports, starting at the first entry
        # (presumably advanced elsewhere as ports are handed out — confirm).
        self.rpc_port_idx = 0
        launcher = Launcher(ports)
        self.mpilauncher = launcher
        launcher.launch()
예제 #8
0
    def __init__(self, device='cpu', **kwargs):
        """
        Create a Hetr transformer and its MPI launcher.

        :param device: device name assigned (with device id 0) to ops that do
            not specify one; also stored as ``self.default_device``.

        Logs a warning when the ``http_proxy`` environment variable is set,
        because proxied traffic to the targeted hosts (including localhost)
        must be excluded through ``no_proxy``.
        """
        super(HetrTransformer, self).__init__(**kwargs)

        proxy_setting = os.getenv('http_proxy')
        if proxy_setting is not None:
            logger.warning(
                "http_proxy environment variable is set, unset http_proxy or "
                "ensure that all targeted hosts (including localhost) "
                "are specified in the no_proxy environment variable")

        self.default_device = device
        self.my_pid = os.getpid()
        self.is_closed = False
        self.child_transformers = {}

        # Shared between the transformer and the communication pass.
        outgoing_nodes = OrderedSet()
        self.send_nodes = outgoing_nodes

        assign_pass = DeviceAssignPass(hetr=self,
                                       default_device=device,
                                       default_device_id=0)
        self.graph_passes = [assign_pass,
                             CommunicationPass(outgoing_nodes),
                             AxesUpdatePass()]

        self.mpilauncher = MPILauncher()