Example #1
0
    def build(self, *args, trace_autograd=False):
        """Builds the plan by tracing a single run of ``self.forward``.

        First, the role is reset and ``self.forward`` is run once with
        tracing enabled. Every leaf of the input arguments is wrapped in a
        PlaceHolder created with ``role=self.role, tracing=True`` so that the
        actions performed on it are recorded.

        Second, the argument placeholders are registered as the role's
        inputs, and the placeholders extracted from the results are
        registered as its outputs (in result order).

        Finally, every translator in ``Plan._build_translators`` is applied.
        A translator that fails only emits a warning.

        Args:
            args: Input arguments to run the plan. Lists, tuples and dicts
                are traversed recursively and each leaf value is wrapped.
            trace_autograd: If True, leaves that are framework tensors with
                ``requires_grad`` set are first wrapped in an AutogradTensor,
                and the PlaceHolder is inserted after it in the chain so that
                autograd operations are traced as well.

        Returns:
            Whatever ``self.forward`` returned during the tracing run.

        Raises:
            Any exception raised while wrapping the arguments or running
            ``self.forward`` propagates unchanged. Tracing is switched off and
            ``is_building`` is reset before it does.
        """
        # Reset previous build
        self.role.reset()

        def build_nested_arg(arg, leaf_function):
            """Rebuild ``arg`` with ``leaf_function`` applied to every leaf.

            Lists, tuples and dicts are rebuilt as list/tuple/dict; any other
            value is treated as a leaf.
            """
            if isinstance(arg, list):
                return [build_nested_arg(obj, leaf_function) for obj in arg]
            elif isinstance(arg, tuple):
                return tuple(
                    [build_nested_arg(obj, leaf_function) for obj in arg])
            elif isinstance(arg, dict):
                return {
                    k: build_nested_arg(v, leaf_function)
                    for k, v in arg.items()
                }
            else:
                return leaf_function(arg)

        # Enable tracing. Everything up to and including the forward run is
        # guarded so that a failure cannot leave the plan stuck in
        # tracing/building mode.
        self.toggle_tracing(True)
        self.is_building = True
        try:
            # typecheck
            self.input_types = NestedTypeWrapper(args)

            # Run once to build the plan
            if trace_autograd:
                # Wrap arguments that require gradients with AutogradTensor,
                # to be able to trace autograd operations
                args = build_nested_arg(
                    args,
                    lambda x: (AutogradTensor().on(x, wrap=False)
                               if isinstance(x, FrameworkTensor)
                               and x.requires_grad else x),
                )
                # Add Placeholder after AutogradTensor in the chain so that
                # all operations that happen inside AutogradTensor are
                # recorded by Placeholder
                args_placeholders = build_nested_arg(
                    args,
                    lambda x: PlaceHolder.insert(
                        x, AutogradTensor, role=self.role, tracing=True),
                )
            else:
                # Add Placeholder on top of each arg
                args = args_placeholders = build_nested_arg(
                    args,
                    lambda x: PlaceHolder.create_from(
                        x, role=self.role, tracing=True),
                )

            # Add state to args if needed (args is a tuple at this point)
            if self.include_state:
                args += (self.state, )

            # Check the plan arguments to see what framework wrappers we
            # might need to send to the plan
            framework_kwargs = {}
            forward_args = inspect.getfullargspec(self.forward).args
            for f_name, wrap_framework_func in (
                    Plan._wrapped_frameworks.items()):
                if f_name in forward_args:
                    framework_kwargs[f_name] = wrap_framework_func(
                        self.role, self.owner)

            results = self.forward(*args, **framework_kwargs)
        finally:
            # Disable tracing, whether or not the forward run succeeded
            self.toggle_tracing(False)
            self.is_building = False

        # Register inputs in role
        self.role.register_inputs(args_placeholders)

        # Register outputs in role
        if isinstance(results, (tuple, list)):
            results_placeholders = tuple(
                PlaceHolder.extract(result) for result in results)
        else:
            results_placeholders = PlaceHolder.extract(results)
        self.role.register_outputs(results_placeholders)

        self.is_built = True

        # Build registered translations. A failing translator is best-effort:
        # it is reported but does not fail the build. Only Exception
        # subclasses are caught so KeyboardInterrupt/SystemExit still
        # propagate.
        for translator in Plan._build_translators:
            try:
                self.add_translation(translator)
                self.translations.append(translator)
            except Exception:
                warnings.warn(
                    f"Failed to translate Plan with {translator.__name__}")

        return results
Example #2
0
    def build(self, *args, trace_autograd=False):
        """Builds the plan by tracing a single run of ``self.forward``.

        First, ``self.forward`` is run once with tracing enabled. Each
        positional argument is wrapped in a PlaceHolder (owned by
        ``sy.local_worker``, with ``role=self.role, tracing=True``) so that
        the actions performed on it are recorded. The run happens inside a
        ``trace(...)`` context over the torch package. If ``forward`` accepts
        a ``torch`` argument, the wrapped package is passed as ``torch=``.

        Second, the argument placeholders are registered as the role's
        inputs, and the placeholders extracted from the results are
        registered as its outputs (in result order).

        Finally, every translator in ``Plan._build_translators`` is applied.
        A translator that fails only emits a warning.

        Args:
            args: Input arguments to run the plan. Each top-level argument is
                wrapped; containers are not traversed.
            trace_autograd: If True, arguments that are framework tensors
                with ``requires_grad`` set are first wrapped in an
                AutogradTensor, and the PlaceHolder is inserted after it in
                the chain so that autograd operations are traced as well.

        Returns:
            Whatever ``self.forward`` returned during the tracing run.

        Raises:
            Any exception raised while wrapping the arguments or running
            ``self.forward`` propagates unchanged. Tracing is switched off and
            ``is_building`` is reset before it does.
        """

        # Enable tracing. Everything up to and including the forward run is
        # guarded so that a failure cannot leave the plan stuck in
        # tracing/building mode.
        self.toggle_tracing(True)
        self.is_building = True
        try:
            if trace_autograd:
                # Wrap arguments that require gradients with AutogradTensor,
                # to be able to trace autograd operations
                args = tuple(
                    AutogradTensor().on(arg, wrap=False)
                    if isinstance(arg, FrameworkTensor) and arg.requires_grad
                    else arg
                    for arg in args)
                # Add Placeholder after AutogradTensor in the chain so that
                # all operations that happen inside AutogradTensor are
                # recorded by Placeholder
                args_placeholders = tuple(
                    PlaceHolder.insert(arg,
                                       AutogradTensor,
                                       owner=sy.local_worker,
                                       role=self.role,
                                       tracing=True) for arg in args)
            else:
                # Add Placeholder on top of each arg
                args = args_placeholders = tuple(
                    PlaceHolder.create_from(
                        arg, owner=sy.local_worker, role=self.role,
                        tracing=True)
                    for arg in args)

            # Add state to args if needed
            if self.include_state:
                args += (self.state, )

            with trace(framework_packages["torch"], self.role,
                       self.owner) as wrapped_torch:
                # Only pass the wrapped framework if forward asks for it
                framework_kwargs = {}
                forward_args = inspect.getfullargspec(self.forward).args
                if "torch" in forward_args:
                    framework_kwargs["torch"] = wrapped_torch

                results = self.forward(*args, **framework_kwargs)
        finally:
            # Disable tracing, whether or not the forward run succeeded
            self.toggle_tracing(False)
            self.is_building = False

        # Register inputs in role
        self.role.register_inputs(args_placeholders)

        # Register outputs in role
        if isinstance(results, (tuple, list)):
            results_placeholders = tuple(
                PlaceHolder.extract(result) for result in results)
        else:
            results_placeholders = PlaceHolder.extract(results)
        self.role.register_outputs(results_placeholders)

        self.is_built = True

        # Build registered translations. A failing translator is best-effort:
        # it is reported but does not fail the build. Only Exception
        # subclasses are caught so KeyboardInterrupt/SystemExit still
        # propagate.
        for translator in Plan._build_translators:
            try:
                self.add_translation(translator)
                self.translations.append(translator)
            except Exception:
                warnings.warn(
                    f"Failed to translate Plan with {translator.__name__}")

        return results