Ejemplo n.º 1
0
    def test_copying_a_component(self):
        # A ReShape that flattens a 2x2 FloatBox to shape (4,), plus a copy of it under a
        # different scope; both live side by side in one container Component.
        space = FloatBox(shape=(2, 2), add_batch_rank=False)

        original = ReShape(flatten=True, scope="A")
        duplicate = original.copy(scope="B")
        container = Component(original, duplicate)

        # One API-method per sub-component so the original and the copy can be exercised separately.
        @rlgraph_api(component=container)
        def flatten1(self, input_):
            return self.sub_components["A"].apply(input_)

        @rlgraph_api(component=container)
        def flatten2(self, input_):
            return self.sub_components["B"].apply(input_)

        test = ComponentTest(component=container, input_spaces=dict(input_=space))

        # (API-method name, 2x2 input, expected flattened output).
        cases = [
            ("flatten1", np.array([[0.5, 2.0], [1.0, 2.0]]), np.array([0.5, 2.0, 1.0, 2.0])),
            ("flatten2", np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([1.0, 2.0, 3.0, 4.0])),
        ]
        for api_name, sample, flattened in cases:
            test.test((api_name, sample), expected_outputs=flattened)
Ejemplo n.º 2
0
    def __init__(self, *sub_components, **kwargs):
        """
        Builds a Stack whose API-methods chain the given sub-Components' API-methods in order.

        Args:
            sub_components (Union[Component,List[Component]]): The sub-components to add to the Stack and connect
                to each other (in the given order).

        Keyword Args:
            api_methods (List[Union[str,Tuple[str,str],dict]]): The API-method specs to connect through the stack.
                Defaults to ["apply"]. All sub-Components must implement all API-methods listed here.
                A plain string names an API-method that both the Stack and the sub-Components share.
                A tuple's 1st item is the Stack's exposed API-method name; its 2nd item is either the name of the
                sub-Components' API-method to chain through or a callable to use as a custom API-method.
                E.g. api_methods=[("stack_run", "run")] creates "stack_run" on the Stack, which calls - one by
                one - all the "run" methods of the sub-Components.
                A spec-dict may carry the keys `api` (exposed API-method name), `component_api` (sub-Components'
                API-method name; defaults to `api`), `function` (custom API-function to use), `fold_time_rank`
                (fold a time rank into the batch rank at the beginning) and `unfold_time_rank` (unfold the time
                rank at the end).

                Chaining always calls the first sub-Component's API-method, then feeds its result into the second
                sub-Component's API-method, and so on, optionally with time-rank folding/unfolding around it.
            scope (str): Scope of the Stack. Defaults to "stack".
            Any remaining kwargs are forwarded to the parent Component constructor.
        """
        # Remove the Stack-only options so they are not forwarded to the parent constructor.
        self.api_methods_options = kwargs.pop("api_methods", ["apply"])
        stack_scope = kwargs.pop("scope", "stack")
        super(Stack, self).__init__(*sub_components, scope=stack_scope, **kwargs)

        self.num_allowed_inputs = None
        self.num_allowed_returns = None

        # Stack API-method name -> sub-Components' API-method name it chains through.
        self.map_api_to_sub_components_api = {}

        # Helper components used for optional time-rank folding (before) and unfolding (after) the chain.
        self.folder = ReShape(fold_time_rank=True, scope="time-rank-folder_")
        self.unfolder = ReShape(unfold_time_rank=True, scope="time-rank-unfolder_")
        self.add_components(self.folder, self.unfolder)

        # Register one Stack API-method per spec.
        self._build_stack(self.api_methods_options)
Ejemplo n.º 3
0
class Stack(Component):
    """
    A component container stack that incorporates one or more sub-components some of whose API-methods
    (default: only `apply`) are automatically connected with each other (in the sequence the sub-Components are given
    in the c'tor), resulting in an API of the Stack.
    All sub-components' API-methods need to match in the number of input and output values. E.g. the third
    sub-component's API-method's number of return values has to match the fourth sub-component's API-method's
    number of input parameters.
    """
    def __init__(self, *sub_components, **kwargs):
        """
        Args:
            sub_components (Union[Component,List[Component]]): The sub-components to add to the Stack and connect
                to each other.

        Keyword Args:
            api_methods (List[Union[str,Tuple[str,str],dict]]): A list of API-method specs to connect
                through the stack.
                Defaults to ["apply"]. All sub-Components must implement all API-methods in this list.
                Alternatively, this list may contain tuples (1st item is the final Stack's API method name, 2nd item
                is either the name of the API-methods of the sub-Components to connect through, or a callable to
                use as the Stack's custom API-method).
                E.g. api_methods=[("stack_run", "run")]. This will create "stack_run" for the Stack, which will call
                - one by one - all the "run" methods of the sub-Components.
                Alternatively, this list may contain spec-dicts with keys:
                `api` (exposed final API-method name), `component_api` (sub-Components API-method names to connect
                through; defaults to `api`), `function` (the custom API-function to use), `fold_time_rank` (whether
                to fold a time rank into a batch rank at the beginning), `unfold_time_rank` (whether to unfold the
                time rank at the end).

                Connecting always works by first calling the first sub-Component's API-method, then - with the
                result - calling the second sub-Component's API-method, etc..
                This is done for all API-methods in the given list, plus - optionally - time rank folding and
                unfolding at the beginning and/or end.
            scope (str): The scope of this Stack. Defaults to "stack".
            All other kwargs are passed on to the parent Component c'tor.
        """
        # Pop Stack-only kwargs first so they are not forwarded to the Component c'tor.
        self.api_methods_options = kwargs.pop("api_methods", ["apply"])
        super(Stack, self).__init__(*sub_components,
                                    scope=kwargs.pop("scope", "stack"),
                                    **kwargs)

        self.num_allowed_inputs = None
        self.num_allowed_returns = None

        # Maps each Stack API-method name to the sub-Components' API-method name it chains through.
        self.map_api_to_sub_components_api = dict()

        # Internal helpers for optional time-rank folding/unfolding in auto-generated API-methods.
        self.folder = ReShape(fold_time_rank=True, scope="time-rank-folder_")
        self.unfolder = ReShape(unfold_time_rank=True, scope="time-rank-unfolder_")

        self.add_components(self.folder, self.unfolder)

        self._build_stack(self.api_methods_options)

    def _build_stack(self, api_methods):
        """
        For each api-method in `api_methods`, automatically create this Stack's own API-method by connecting
        through all sub-Component's API-methods. This is skipped if this Stack already has an attribute
        by that name (e.g. a custom API-method).

        Args:
            api_methods (List[Union[str,Tuple[str,str],dict]]): See ctor kwargs.
        """
        # Loop through the API-method specs and register each one.
        for api_method_spec in api_methods:
            function_to_use = None
            fold_time_rank = False
            unfold_time_rank = False

            # Detailed spec-dict given.
            if isinstance(api_method_spec, dict):
                stack_api_method_name = api_method_spec["api"]
                component_api_method_name = api_method_spec.get("component_api", stack_api_method_name)
                function_to_use = api_method_spec.get("function")
                fold_time_rank = api_method_spec.get("fold_time_rank", False)
                unfold_time_rank = api_method_spec.get("unfold_time_rank", False)
            # Tuple: either (name, custom-callable) or (stack-name, sub-Components' API-method name).
            elif isinstance(api_method_spec, tuple):
                # Custom method given, use that instead of creating one automatically.
                if callable(api_method_spec[1]):
                    stack_api_method_name = component_api_method_name = api_method_spec[0]
                    function_to_use = api_method_spec[1]
                else:
                    stack_api_method_name, component_api_method_name = api_method_spec[0], api_method_spec[1]
            # API-method of sub-Components and this Stack should have the same name.
            else:
                stack_api_method_name = component_api_method_name = api_method_spec

            self.map_api_to_sub_components_api[stack_api_method_name] = component_api_method_name

            # API-method for this Stack does not exist yet -> Automatically create it.
            if not hasattr(self, stack_api_method_name):
                # Custom API-method is given (w/o decorator) -> Call the decorator directly here to register it.
                if function_to_use is not None:
                    rlgraph_api(api_method=function_to_use, component=self, name=stack_api_method_name)
                # No API-method given -> Create auto-API-method and set it up through decorator.
                else:
                    self.build_auto_api_method(
                        stack_api_method_name,
                        component_api_method_name,
                        fold_time_rank=fold_time_rank,
                        unfold_time_rank=unfold_time_rank
                    )

    def build_auto_api_method(self,
                              stack_api_method_name,
                              sub_components_api_method_name,
                              fold_time_rank=False,
                              unfold_time_rank=False,
                              ok_to_overwrite=False):
        """
        Creates and registers an auto-API method for this stack.

        Args:
            stack_api_method_name (str): The name for the (exposed) API-method of the Stack.

            sub_components_api_method_name (str): The name of the single sub-components' API-methods to call one after
                another.

            fold_time_rank (bool): Whether to pass the first positional input through this Stack's `folder`
                (a time-rank folding ReShape) before calling the first sub-Component.

            unfold_time_rank (bool): Whether to pass the single return value through this Stack's `unfolder`
                (a time-rank unfolding ReShape) after the last sub-Component. If `fold_time_rank` is False, the
                generated method expects the original input as its second positional argument; that argument is
                used only for unfolding and is not passed on to the sub-Components.

            ok_to_overwrite (Optional[bool]): Set to True if we know we are overwriting an already existing
                API-method of the same name.
        """
        @rlgraph_api(name=stack_api_method_name, component=self, ok_to_overwrite=ok_to_overwrite)
        def method(self_, *inputs, **kwargs):
            # NOTE: Use `self_` (the component the API-method is invoked on), not the closed-over `self`.
            # Functions are not duplicated by `copy.deepcopy`, so on a copied Stack `self` would still point
            # to the original Stack's folder/unfolder while `self_.sub_components` belongs to the copy.
            original_input = None
            # Fold time rank? For now only support 1st arg folding/unfolding.
            if fold_time_rank is True:
                original_input = inputs[0]
                args_ = tuple([self_.folder.apply(original_input)] + list(inputs[1:]))
            elif unfold_time_rank is True:
                # TODO: If only unfolding: Assume for now that 2nd input is the original one (so we can infer
                # TODO: batch/time dims).
                assert len(inputs) >= 2, \
                    "ERROR: In Stack: If unfolding w/o folding, second arg must be the original input!"
                original_input = inputs[1]
                args_ = tuple([inputs[0]] + list(inputs[2:]))
            else:
                # No folding/unfolding: inputs may legitimately be empty (kwargs-only call).
                args_ = inputs
            kwargs_ = kwargs

            for sub_component in self_.sub_components.values():  # type: Component
                # The time-rank helpers are sub-components too, but are not part of the chain.
                if sub_component.scope in ("time-rank-folder_", "time-rank-unfolder_"):
                    continue
                # TODO: python-Components: For now, we call each preprocessor's graph_fn
                #  directly (assuming that inputs are not ContainerSpaces). kwargs_ are not forwarded here.
                if self_.backend == "python" or get_backend() == "python":
                    graph_fn = getattr(sub_component, "_graph_fn_" + sub_components_api_method_name)
                    results = graph_fn(*args_)
                elif get_backend() == "pytorch":
                    # Do NOT convert to tuple, has to be unpacked again immediately. kwargs_ are not forwarded here.
                    results = getattr(sub_component, sub_components_api_method_name)(*force_list(args_))
                else:  # if get_backend() == "tf":
                    results = getattr(sub_component, sub_components_api_method_name)(*args_, **kwargs_)

                # Recycle args_, kwargs_ for reuse in next sub-Component's API-method call:
                # dict results become keyword args, everything else positional args.
                if isinstance(results, dict):
                    args_ = ()
                    kwargs_ = results
                else:
                    args_ = force_tuple(results)
                    kwargs_ = {}

            # Keyword-style (dict) result.
            if len(args_) == 0:
                # Unfold time rank? For now only support 1st arg folding/unfolding.
                if unfold_time_rank is True:
                    assert len(kwargs_) == 1, \
                        "ERROR: time-rank-unfolding not supported for more than one NN-return value!"
                    key = next(iter(kwargs_))
                    kwargs_ = {key: self_.unfolder.apply(kwargs_[key], original_input)}
                return kwargs_

            # Positional result(s).
            if unfold_time_rank is True:
                assert len(args_) == 1, \
                    "ERROR: time-rank-unfolding not supported for more than one NN-return value!"
                # Unfold the first return value; keep any further return values unchanged.
                args_ = (self_.unfolder.apply(args_[0], original_input),) + tuple(args_[1:])
            if len(args_) == 1:
                return args_[0]
            return args_

    @classmethod
    def from_spec(cls, spec=None, **kwargs):
        """
        Creates a Stack from a spec.

        A dict spec may list its sub-Components under the "layers" key; a tuple/list spec is treated as the
        sequence of sub-Components itself. In both cases the sub-Components are handed to
        `Component.from_spec` under the `_args` kwarg. The spec is deep-copied first so the caller's
        spec is not mutated by the `pop` below.
        """
        spec_deepcopy = copy.deepcopy(spec)
        if isinstance(spec, dict):
            kwargs["_args"] = list(spec_deepcopy.pop("layers", []))
        elif isinstance(spec, (tuple, list)):
            kwargs["_args"] = spec_deepcopy
            spec_deepcopy = None
        return super(Stack, cls).from_spec(spec_deepcopy, **kwargs)