def build(self, *args, trace_autograd=False):
    """Builds the plan.

    First, run the function to be converted in a plan in a context which
    activates the tracing and record the actions in trace.logs

    Second, store the result ids temporarily to helper ordering the output
    placeholders at return time

    Third, loop through the trace logs and replace the tensors found in the
    actions logged by PlaceHolders. Record those actions in plan.actions

    Args:
        args: Input arguments to run the plan
        trace_autograd: If True, wrap gradient-requiring tensor args with
            AutogradTensor so autograd operations are traced as well.

    Returns:
        The result(s) of running ``self.forward`` on the placeholder-wrapped
        arguments.
    """
    # Reset previous build
    self.role.reset()

    def build_nested_arg(arg, leaf_function):
        """Apply leaf_function to every leaf of an arbitrarily nested
        list/tuple/dict argument structure, preserving the structure."""
        if isinstance(arg, list):
            return [build_nested_arg(obj, leaf_function) for obj in arg]
        elif isinstance(arg, tuple):
            # Generator is enough here; no need to materialize a list first.
            return tuple(build_nested_arg(obj, leaf_function) for obj in arg)
        elif isinstance(arg, dict):
            return {k: build_nested_arg(v, leaf_function) for k, v in arg.items()}
        else:
            return leaf_function(arg)

    # Enable tracing
    self.toggle_tracing(True)
    self.is_building = True

    # typecheck: remember the nested structure/types of the build-time args
    self.input_types = NestedTypeWrapper(args)

    # Run once to build the plan
    if trace_autograd:
        # Wrap arguments that require gradients with AutogradTensor,
        # to be able to trace autograd operations
        args = build_nested_arg(
            args,
            lambda x: AutogradTensor().on(x, wrap=False)
            if isinstance(x, FrameworkTensor) and x.requires_grad
            else x,
        )
        # Add Placeholder after AutogradTensor in the chain
        # so that all operations that happen inside AutogradTensor are recorded by Placeholder
        args_placeholders = build_nested_arg(
            args,
            lambda x: PlaceHolder.insert(x, AutogradTensor, role=self.role, tracing=True),
        )
    else:
        # Add Placeholder on top of each arg
        args = args_placeholders = build_nested_arg(
            args,
            lambda x: PlaceHolder.create_from(x, role=self.role, tracing=True),
        )

    # Add state to args if needed
    if self.include_state:
        args += (self.state,)

    # Check the plan arguments to see what framework wrappers we might need
    # to send to the plan (e.g. a traced `torch` module passed as a kwarg)
    framework_kwargs = {}
    forward_args = inspect.getfullargspec(self.forward).args
    for f_name, wrap_framework_func in Plan._wrapped_frameworks.items():
        if f_name in forward_args:
            framework_kwargs[f_name] = wrap_framework_func(self.role, self.owner)

    results = self.forward(*args, **framework_kwargs)

    # Disable tracing
    self.toggle_tracing(False)
    self.is_building = False

    # Register inputs in role
    self.role.register_inputs(args_placeholders)

    # Register outputs in role
    if isinstance(results, (tuple, list)):
        results_placeholders = tuple(PlaceHolder.extract(result) for result in results)
    else:
        results_placeholders = PlaceHolder.extract(results)
    self.role.register_outputs(results_placeholders)

    self.is_built = True

    # Build registered translations; translation is best-effort, but a bare
    # `except:` would also swallow KeyboardInterrupt/SystemExit, so catch
    # Exception explicitly.
    for translator in Plan._build_translators:
        try:
            self.add_translation(translator)
            self.translations.append(translator)
        except Exception:
            warnings.warn(f"Failed to translate Plan with {translator.__name__}")

    return results
def build(self, *args, trace_autograd=False):
    """Builds the plan.

    First, run the function to be converted in a plan in a context which
    activates the tracing and record the actions in trace.logs

    Second, store the result ids temporarily to helper ordering the output
    placeholders at return time

    Third, loop through the trace logs and replace the tensors found in the
    actions logged by PlaceHolders. Record those actions in plan.actions

    Args:
        args: Input arguments to run the plan
        trace_autograd: If True, wrap gradient-requiring tensor args with
            AutogradTensor so autograd operations are traced as well.

    Returns:
        The result(s) of running ``self.forward`` on the placeholder-wrapped
        arguments.
    """
    # Enable tracing
    self.toggle_tracing(True)
    self.is_building = True

    if trace_autograd:
        # Wrap arguments that require gradients with AutogradTensor,
        # to be able to trace autograd operations
        args = tuple(
            AutogradTensor().on(arg, wrap=False)
            if isinstance(arg, FrameworkTensor) and arg.requires_grad
            else arg
            for arg in args
        )
        # Add Placeholder after AutogradTensor in the chain
        # so that all operations that happen inside AutogradTensor are recorded by Placeholder
        args_placeholders = tuple(
            PlaceHolder.insert(
                arg, AutogradTensor, owner=sy.local_worker, role=self.role, tracing=True
            )
            for arg in args
        )
    else:
        # Add Placeholder on top of each arg
        args = args_placeholders = tuple(
            PlaceHolder.create_from(arg, owner=sy.local_worker, role=self.role, tracing=True)
            for arg in args
        )

    # Add state to args if needed
    if self.include_state:
        args += (self.state,)

    # Run forward under a traced torch wrapper so torch-level calls made by
    # the plan body are recorded in the role.
    with trace(framework_packages["torch"], self.role, self.owner) as wrapped_torch:
        # Look for framework kwargs: pass the traced torch module if the
        # plan's forward declares a `torch` parameter.
        framework_kwargs = {}
        forward_args = inspect.getfullargspec(self.forward).args
        if "torch" in forward_args:
            framework_kwargs["torch"] = wrapped_torch

        results = self.forward(*args, **framework_kwargs)

    # Disable tracing
    self.toggle_tracing(False)
    self.is_building = False

    # Register inputs in role
    self.role.register_inputs(args_placeholders)

    # Register outputs in role
    if isinstance(results, (tuple, list)):
        results_placeholders = tuple(PlaceHolder.extract(result) for result in results)
    else:
        results_placeholders = PlaceHolder.extract(results)
    self.role.register_outputs(results_placeholders)

    self.is_built = True

    # Build registered translations; translation is best-effort, but a bare
    # `except:` would also swallow KeyboardInterrupt/SystemExit, so catch
    # Exception explicitly.
    for translator in Plan._build_translators:
        try:
            self.add_translation(translator)
            self.translations.append(translator)
        except Exception:
            warnings.warn(f"Failed to translate Plan with {translator.__name__}")

    return results