def run_party(cid, func, rank, world_size, master_addr, master_port, func_args, func_kwargs):
    """Start a CrypTen party locally and run the computation.

    Args:
        cid (int): CrypTen computation id.
        func (function): computation to be done.
        rank (int): rank of the CrypTen party.
        world_size (int): number of CrypTen parties involved in the computation.
        master_addr (str): IP address of the master party (party with rank 0).
        master_port (int or str): port of the master party (party with rank 0).
        func_args (list): arguments to be passed to func.
        func_kwargs (dict): keyword arguments to be passed to func.

    Returns:
        The return value of func.
    """
    process, queue = _new_party(
        cid, func, rank, world_size, master_addr, master_port, func_args, func_kwargs
    )
    was_initialized = DistributedCommunicator.is_initialized()
    if was_initialized:
        crypten.uninit()
    process.start()
    # Wait for the party to finish and post its result on the queue
    res = queue.get()
    if was_initialized:
        crypten.init()
    return res
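

# Usage sketch for run_party (illustrative only, not part of the API): runs a
# single-party computation on the local machine. The helper _demo_computation,
# the computation id, and the port are assumptions made for this example;
# _new_party (used internally by run_party) is defined in this module.
def _demo_computation():
    import torch

    # Encrypt a tensor and immediately decrypt it again
    x = crypten.cryptensor(torch.tensor([1.0, 2.0, 3.0]))
    return x.get_plain_text()


def _example_run_party():
    return run_party(
        cid=0,
        func=_demo_computation,
        rank=0,
        world_size=1,  # single party, for demonstration
        master_addr="127.0.0.1",
        master_port=15987,  # assumed free port
        func_args=[],
        func_kwargs={},
    )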
def wrapper(*args, **kwargs):
    rendezvous_file = tempfile.NamedTemporaryFile(delete=True).name
    queue = multiprocessing.Queue()
    processes = [
        multiprocessing.Process(
            target=_launch,
            args=(func, rank, world_size, rendezvous_file, queue, args, kwargs),
        )
        for rank in range(world_size)
    ]

    # This process will be forked, and we need to re-initialize the
    # communicator in the children. If the parent process happened to
    # call crypten.init(), which might be valid in a Jupyter notebook
    # for instance, then the crypten.init() call in the child processes
    # would not do anything. The call to uninit here makes sure we
    # actually get to initialize the communicator on the child process.
    # An alternative fix for this issue would be to use spawn instead
    # of fork, but we run into issues serializing the function in that
    # case.
    was_initialized = DistributedCommunicator.is_initialized()
    if was_initialized:
        crypten.uninit()

    for process in processes:
        process.start()
    for process in processes:
        process.join()

    if was_initialized:
        crypten.init()

    successful = [process.exitcode == 0 for process in processes]
    if not all(successful):
        logging.error("One of the parties failed. Check past logs")
        return None

    return_values = []
    while not queue.empty():
        return_values.append(queue.get())
    # Each queue entry is a (rank, value) pair; order the results by rank
    return [value for _, value in sorted(return_values, key=itemgetter(0))]
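

# Usage sketch: the wrapper above has the shape of a run_multiprocess-style
# decorator (as in CrypTen's crypten.mpc.run_multiprocess). Assuming it is
# exposed that way, a decorated function runs once per rank, and the call
# returns the per-rank results ordered by rank, or None if any party failed.
# The decorator name and world_size below are assumptions for illustration.
def _example_run_multiprocess():
    import crypten
    import crypten.mpc as mpc

    @mpc.run_multiprocess(world_size=2)  # decorator name assumed
    def secret_add():
        # src marks which rank contributes the plaintext value
        a = crypten.cryptensor([1.0], src=0)
        b = crypten.cryptensor([2.0], src=1)
        return (a + b).get_plain_text()

    return secret_add()  # e.g. [tensor([3.]), tensor([3.])]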
def wrapper(*args, **kwargs):
    # TODO:
    # - check if workers are reachable / they can handle the computation
    # - check return code of processes for possible failure

    if len(workers) != len(set(worker.id for worker in workers)):  # noqa: C401
        raise RuntimeError("found workers with the same ID, but worker IDs must be unique")

    if model is not None:
        if not isinstance(model, th.nn.Module):
            raise TypeError("model must be a torch.nn.Module")
        if dummy_input is None:
            raise ValueError("must provide dummy_input when model is set")
        if not isinstance(dummy_input, th.Tensor):
            raise TypeError("dummy_input must be a torch.Tensor")
        onnx_model = utils.pytorch_to_onnx(model, dummy_input)
    else:
        onnx_model = None

    crypten_model = None if onnx_model is None else utils.onnx_to_crypten(onnx_model)

    world_size = len(workers)
    manager = multiprocessing.Manager()
    return_values = manager.dict({rank: None for rank in range(world_size)})

    rank_to_worker_id = dict(
        zip(range(0, len(workers)), [worker.id for worker in workers])
    )

    # TODO: run the TTP in a specified worker
    # if crypten.mpc.ttp_required():
    #     ttp_process, _ = _new_party(
    #         crypten.mpc.provider.TTPServer,
    #         world_size,
    #         world_size,
    #         master_addr,
    #         master_port,
    #         (),
    #         {},
    #     )
    #     ttp_process.start()

    if isinstance(func, sy.Plan):
        plan = func

        # This is needed because at build time we use a set of methods
        # defined in syft (e.g. load)
        hook_plan_building()
        was_initialized = DistributedCommunicator.is_initialized()
        if not was_initialized:
            crypten.init()

        # Build the plan with the crypten model (when there is one) so
        # that the actions traced inside the plan know about its existence
        if crypten_model is None:
            plan.build()
        else:
            plan.build(crypten_model)

        if not was_initialized:
            crypten.uninit()
        unhook_plan_building()

        # Mark the plan so the other workers will use that tag to retrieve it
        plan.tags = ["crypten_plan"]

        for worker in workers:
            plan.send(worker)

        msg = CryptenInitPlan(
            (rank_to_worker_id, world_size, master_addr, master_port), onnx_model
        )

    else:  # func
        jail_runner = jail.JailRunner(func=func)
        ser_jail_runner = jail.JailRunner.simplify(jail_runner)

        msg = CryptenInitJail(
            (rank_to_worker_id, world_size, master_addr, master_port),
            ser_jail_runner,
            onnx_model,
        )

    # Send messages to the other workers so they start their parties.
    # Note that despite the `thread` name, each sender runs in its own
    # process here; results come back through the shared manager dict.
    threads = []
    for i in range(len(workers)):
        rank = i
        thread = multiprocessing.Process(
            target=_send_party_info,
            args=(workers[i], rank, msg, return_values, crypten_model),
        )
        thread.start()
        threads.append(thread)

    # Wait for the workers running the parties to return a response
    for thread in threads:
        thread.join()

    return return_values
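

# Usage sketch: this wrapper is the body of a decorator that distributes a
# computation across remote syft workers (e.g. run_multiworkers). The
# decorator name, its signature, and the worker setup below are assumptions
# made for illustration; the call returns a dict-like mapping of
# rank -> that party's return value.
def _example_run_multiworkers():
    import torch as th
    import syft as sy

    hook = sy.TorchHook(th)
    alice = sy.VirtualWorker(hook, id="alice")
    bob = sy.VirtualWorker(hook, id="bob")

    @run_multiworkers([alice, bob], master_addr="127.0.0.1", master_port=15987)
    @sy.func2plan()
    def plan_func():
        t = crypten.cryptensor([1.0, 2.0], src=0)
        return t.get_plain_text()

    return plan_func()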
def wrapper(*args, **kwargs):
    # TODO:
    # - check if workers are reachable / they can handle the computation
    # - check return code of processes for possible failure

    if model is not None:
        if not isinstance(model, th.nn.Module):
            raise TypeError("model must be a torch.nn.Module")
        if dummy_input is None:
            raise ValueError("must provide dummy_input when model is set")
        if not isinstance(dummy_input, th.Tensor):
            raise TypeError("dummy_input must be a torch.Tensor")
        onnx_model = utils.pytorch_to_onnx(model, dummy_input)
    else:
        onnx_model = None

    crypten_model = None if onnx_model is None else utils.onnx_to_crypten(onnx_model)

    world_size = len(workers) + 1
    return_values = {rank: None for rank in range(world_size)}

    if isinstance(func, sy.Plan):
        using_plan = True
        plan = func

        # This is needed because at build time we use a set of methods
        # defined in syft (e.g. load)
        hook_plan_building()
        crypten.init()
        plan.build()
        crypten.uninit()
        unhook_plan_building()

        # Mark the plan so the other workers will use that tag to retrieve it
        plan.tags = ["crypten_plan"]
        for worker in workers:
            plan.send(worker)
        jail_or_plan = plan

    else:  # func
        using_plan = False
        jail_runner = jail.JailRunner(func=func, model=crypten_model)
        ser_jail_runner = jail.JailRunner.simplify(jail_runner)
        jail_or_plan = jail_runner

    rank_to_worker_id = dict(
        zip(range(1, len(workers) + 1), [worker.id for worker in workers])
    )
    sy.local_worker._set_rank_to_worker_id(rank_to_worker_id)

    # Start the local party (rank 0)
    process, queue = _new_party(jail_or_plan, 0, world_size, master_addr, master_port, (), {})

    was_initialized = DistributedCommunicator.is_initialized()
    if was_initialized:
        crypten.uninit()
    process.start()

    # Run the TTP if required
    # TODO: run the TTP in a specified worker
    if crypten.mpc.ttp_required():
        ttp_process, _ = _new_party(
            crypten.mpc.provider.TTPServer,
            world_size,
            world_size,
            master_addr,
            master_port,
            (),
            {},
        )
        ttp_process.start()

    # Send messages to the other workers so they start their parties
    threads = []
    for i in range(len(workers)):
        rank = i + 1
        if using_plan:
            msg = CryptenInitPlan(
                (rank_to_worker_id, world_size, master_addr, master_port)
            )
        else:  # jail
            msg = CryptenInitJail(
                (rank_to_worker_id, world_size, master_addr, master_port),
                ser_jail_runner,
                onnx_model,
            )
        thread = threading.Thread(
            target=_send_party_info, args=(workers[i], rank, msg, return_values)
        )
        thread.start()
        threads.append(thread)

    # Wait for the local party and the sender threads. Joining the local
    # process would block, but queue.get() also waits for the party and
    # works fine here.
    # process.join() -> blocks
    local_party_result = queue.get()
    return_values[0] = utils.unpack_values(local_party_result, crypten_model)
    for thread in threads:
        thread.join()
    if was_initialized:
        crypten.init()

    return return_values
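

# Illustrative rank layout for the wrapper above (bookkeeping only, derived
# from the code itself): rank 0 is reserved for the local party and remote
# workers take ranks 1..len(workers), which is why world_size is
# len(workers) + 1 here, unlike the remote-only variant above.
def _example_rank_layout(workers):
    world_size = len(workers) + 1  # local party + remote workers
    rank_to_worker_id = dict(
        zip(range(1, len(workers) + 1), [w.id for w in workers])
    )
    # For workers [alice, bob]: {1: "alice", 2: "bob"}; rank 0 is local
    return world_size, rank_to_worker_id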