def check_device_assign_pass(default_device, default_device_id,
                             graph_op_metadata, graph_op=None, *args):
    """
    Check that DeviceAssignPass injects the expected device metadata.

    The Device assign pass should inject the metadata {device_id, device}
    as specified by the user for each op. If not specified, the default
    {device_id:0, device:cpu} should be inserted for each op.

    :param default_device: string, the default device for each op if not
        specified by the user, e.g. "cpu".
    :param default_device_id: string, the default device number for each op
        if not specified by the user, e.g. "0".
    :param graph_op_metadata: dict mapping each op to a sequence whose first
        item is the expected device and whose second item is the expected
        device_id (a single id, or a list/tuple of ids).
    :param graph_op: collection of ops to do the graph traversal on. Defaults
        to a fresh, empty OrderedSet on every call.
    :param args: ignored.
    """
    # A fresh container per call: a default evaluated once at def time would
    # be shared (and possibly mutated by the pass) across calls.
    if graph_op is None:
        graph_op = OrderedSet()

    with ExecutorFactory():
        expected_transformers = set()

        class MockHetr(object):
            """Stand-in for the hetr transformer; records registrations."""

            def __init__(self):
                self.transformers = set()

            def register_transformer(self, transformer):
                self.transformers.add(transformer)

        hetr = MockHetr()
        obj = DeviceAssignPass(hetr, default_device, default_device_id)
        obj.do_pass(ops=graph_op)

        for op, expected in graph_op_metadata.items():
            device = expected[0]
            device_id = expected[1]
            assert op.metadata['device'] == device
            assert op.metadata['device_id'] == device_id

            # Normalize to a tuple of ids. The previous code indexed into a
            # scalar id (keeping only its first character, so "10" became
            # "1") and iterated a scalar id character by character.
            if isinstance(device_id, (list, tuple)):
                device_ids = tuple(device_id)
                transformer = [device + str(i) for i in device_ids]
            else:
                device_ids = (device_id,)
                transformer = device + str(device_id)
            assert op.metadata['transformer'] == transformer

            for i in device_ids:
                expected_transformers.add(device + str(i))

        assert hetr.transformers == expected_transformers
def __init__(self, **kwargs):
    """
    Initialise transformer bookkeeping and the hetr pass pipeline.

    Keyword arguments are forwarded unchanged to the base-class constructor
    (``super(HetrTransformer, self).__init__``).

    Builds, in order: DeviceAssignPass (default device 'numpy', id 0,
    recording into ``self.transformers``), CommunicationPass,
    DistributedPass (the latter two share ``self.send_nodes`` and the
    scatter/gather queue lists) and ChildTransformerPass.

    Finally bumps the class-wide ``HetrTransformer.hetr_counter``, asserting
    that it stays within [0, 1].
    """
    super(HetrTransformer, self).__init__(**kwargs)
    self.my_pid = os.getpid()
    self.is_closed = False
    self.child_transformers = dict()
    self.transformer_list = list()
    self.transformers = set()
    self.send_nodes = OrderedSet()
    self.scatter_shared_queues = list()
    self.gather_shared_queues = list()
    self.hetr_passes = [
        DeviceAssignPass(default_device='numpy',
                         default_device_id=0,
                         transformers=self.transformers),
        CommunicationPass(self.send_nodes,
                          self.scatter_shared_queues,
                          self.gather_shared_queues),
        DistributedPass(self.send_nodes,
                        self.scatter_shared_queues,
                        self.gather_shared_queues),
        ChildTransformerPass(self.transformer_list)
    ]
    self.vizpass = None
    self.inits = OrderedSet()

    # Validate the would-be count *before* mutating the class attribute.
    # Incrementing first meant a failed assertion left the counter bumped
    # even though construction failed and no instance exists.
    new_count = HetrTransformer.hetr_counter + 1
    assert 0 <= new_count <= 1, (
        "HetrTransformer.hetr_counter would become %d; expected 0 or 1"
        % new_count)
    HetrTransformer.hetr_counter = new_count
def __init__(self, device='cpu', **kwargs):
    """
    Initialise the transformer and optionally launch hetr servers via mpirun.

    :param device: default device assigned by DeviceAssignPass to ops that
        do not specify one.
    :param kwargs: forwarded unchanged to the base-class constructor.

    When both the HETR_SERVER_NUM and HETR_SERVER_HOSTFILE environment
    variables are set, ``mpirun`` is invoked to start ``cpu/hetr_server.py``
    (located next to this file) and ``self.use_mlsl`` is set to True.
    Otherwise ``self.use_mlsl`` is False.
    """
    super(HetrTransformer, self).__init__(**kwargs)
    self.my_pid = os.getpid()
    self.is_closed = False
    self.child_transformers = dict()
    self.send_nodes = OrderedSet()
    self.graph_passes = [
        DeviceAssignPass(hetr=self, default_device=device, default_device_id=0),
        CommunicationPass(self.send_nodes),
        DistributedPass(self.send_nodes)
    ]

    hetr_server_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "cpu", "hetr_server.py")
    hetr_server_num = os.getenv('HETR_SERVER_NUM')
    hetr_server_hostfile = os.getenv('HETR_SERVER_HOSTFILE')

    # Assumption is that hydra_persist processes are started on remote nodes.
    # Otherwise, remove "-bootstrap persist" from the command line (it then
    # uses ssh).
    if hetr_server_num is not None and hetr_server_hostfile is not None:
        # Pass an argument list instead of a shell string: the values come
        # from the environment and must not be word-split or interpreted by
        # a shell (paths with spaces, metacharacters, injection).
        mpirun_cmd = [
            "mpirun",
            "-n", hetr_server_num,
            "-ppn", "1",
            "-bootstrap", "persist",
            "-hostfile", hetr_server_hostfile,
            hetr_server_path,
        ]
        subprocess.call(mpirun_cmd)
        self.use_mlsl = True
    else:
        self.use_mlsl = False
def check_device_assign_pass(default_device, default_device_id,
                             graph_op_metadata, graph_op=None, *args):
    """
    Check that DeviceAssignPass injects the expected device metadata.

    The Device assign pass should inject the metadata {device_id, device}
    as specified by the user for each op. If not specified, the default
    {device_id:0, device:numpy} should be inserted for each op.

    :param default_device: string, the default device for each op if not
        specified by the user, e.g. "numpy".
    :param default_device_id: the default device number for each op if not
        specified by the user, e.g. "0".
    :param graph_op_metadata: dict mapping each op to a sequence whose first
        item is the expected device and whose second item is the expected
        device_id.
    :param graph_op: collection of ops to do the graph traversal on. Defaults
        to a fresh, empty OrderedSet on every call.
    :param args: ignored.
    """
    # A fresh container per call instead of one shared mutable default.
    if graph_op is None:
        graph_op = OrderedSet()

    transformer = ngt.make_transformer_factory('hetr')()
    # Always close the transformer, even when an assertion or the pass itself
    # raises; otherwise a failing check leaks it.
    try:
        transformers = set()
        expected_transformers = set()
        obj = DeviceAssignPass(default_device, default_device_id, transformers)
        obj.do_pass(graph_op, transformer)

        for op, expected in graph_op_metadata.items():
            device = expected[0]
            device_id = expected[1]
            assert op.metadata['device'] == device
            assert op.metadata['device_id'] == device_id
            assert op.metadata['transformer'] == device + str(device_id)
            expected_transformers.add(op.metadata['transformer'])

        assert transformers == expected_transformers
    finally:
        transformer.close()
def __init__(self, **kwargs):
    """
    Set up process bookkeeping and the ordered list of graph passes.

    All keyword arguments go straight to the base-class constructor.
    The pass pipeline is: device assignment (default 'cpu', id 0), then
    communication, then distribution.
    """
    super(HetrTransformer, self).__init__(**kwargs)
    self.my_pid = os.getpid()
    self.is_closed = False
    self.child_transformers = {}

    # The communication and distributed passes are both handed this same
    # OrderedSet instance.
    shared_send_nodes = OrderedSet()
    self.send_nodes = shared_send_nodes

    device_pass = DeviceAssignPass(hetr=self,
                                   default_device='cpu',
                                   default_device_id=0)
    comm_pass = CommunicationPass(shared_send_nodes)
    dist_pass = DistributedPass(shared_send_nodes)
    self.graph_passes = [device_pass, comm_pass, dist_pass]
def __init__(self, device='cpu', **kwargs):
    """
    Set up bookkeeping, the graph-pass pipeline and an MPI launcher.

    :param device: default device that DeviceAssignPass gives to ops without
        an explicit one; also stored as ``self.default_device``.
    :param kwargs: forwarded to the base-class constructor.
    """
    super(HetrTransformer, self).__init__(**kwargs)
    self.default_device = device
    self.my_pid = os.getpid()
    self.is_closed = False
    self.child_transformers = {}
    self.send_nodes = OrderedSet()

    # Pipeline order: device assignment -> communication -> axes update.
    passes = [DeviceAssignPass(hetr=self,
                               default_device=device,
                               default_device_id=0)]
    passes.append(CommunicationPass(self.send_nodes))
    passes.append(AxesUpdatePass())
    self.graph_passes = passes

    self.mpilauncher = MPILauncher()
def __init__(self, device='cpu', **kwargs):
    """
    Set up bookkeeping and graph passes, then start the process launcher.

    :param device: default device that DeviceAssignPass gives to ops without
        an explicit one.
    :param kwargs: forwarded to the base-class constructor.

    Ports from ``get_available_ports()`` are stored in ``self.rpc_ports``
    (with ``self.rpc_port_idx`` starting at 0) and handed to the Launcher,
    whose ``launch()`` is called before returning.
    """
    super(HetrTransformer, self).__init__(**kwargs)
    self.my_pid = os.getpid()
    self.is_closed = False
    self.child_transformers = {}
    self.send_nodes = OrderedSet()

    device_pass = DeviceAssignPass(hetr=self,
                                   default_device=device,
                                   default_device_id=0)
    comm_pass = CommunicationPass(self.send_nodes)
    dist_pass = DistributedPass(self.send_nodes)
    self.graph_passes = [device_pass, comm_pass, dist_pass]

    ports = get_available_ports()
    self.rpc_ports = ports
    self.rpc_port_idx = 0
    launcher = Launcher(ports)
    self.mpilauncher = launcher
    launcher.launch()
def __init__(self, device='cpu', **kwargs):
    """
    Set up bookkeeping, the graph-pass pipeline and an MPI launcher.

    Logs a warning when the ``http_proxy`` environment variable is set.

    :param device: default device that DeviceAssignPass gives to ops without
        an explicit one; also stored as ``self.default_device``.
    :param kwargs: forwarded to the base-class constructor.
    """
    super(HetrTransformer, self).__init__(**kwargs)

    proxy_setting = os.getenv('http_proxy')
    if proxy_setting is not None:
        logger.warning(
            "http_proxy environment variable is set, unset http_proxy or "
            "ensure that all targeted hosts (including localhost) "
            "are specified in the no_proxy environment variable")

    self.default_device = device
    self.my_pid = os.getpid()
    self.is_closed = False
    self.child_transformers = {}
    self.send_nodes = OrderedSet()

    # Pipeline order: device assignment -> communication -> axes update.
    device_pass = DeviceAssignPass(hetr=self,
                                   default_device=device,
                                   default_device_id=0)
    comm_pass = CommunicationPass(self.send_nodes)
    axes_pass = AxesUpdatePass()
    self.graph_passes = [device_pass, comm_pass, axes_pass]

    self.mpilauncher = MPILauncher()