예제 #1
0
def reverse_model(
    model: Model,
    reverse_mappings,  # TODO: type annotate reverse_mappings
    default_reverse_mapping: Optional[Callable] = None,
    head_mapping: Optional[Callable] = None,
    stop_mapping_at_tensors: Optional[List[Tensor]] = None,
    verbose: bool = False,
    return_all_reversed_tensors: bool = False,
    clip_all_reversed_tensors: Union[bool, Tuple[float, float]] = False,
    project_bottleneck_tensors: Union[bool, Tuple[float, float]] = False,
    execution_trace: Optional[Tuple[List[Layer],
                                    List[Tuple[Layer, List[Tensor],
                                               List[Tensor]]],
                                    List[Tensor]]] = None,
    reapply_on_copied_layers: bool = False,
) -> Union[List[Tensor], Tuple[List[Tensor], Dict[Tensor, ReverseTensorDict]]]:
    """
    Reverses a Keras model based on the given reverse functions.
    It returns the reverted tensors for the according model inputs.

    :param model: A Keras model.
    :param reverse_mappings: Either a callable that maps a layer to its
      mapping (or None for "no mapping"), or a dictionary whose keys are
      layer *classes*. A dictionary is looked up with the exact
      ``type(layer)`` (subclasses do not match); a missing key means
      "no mapping".
      Allowed as mapping forms are:
          * A function of form (A) f(Xs, Ys, reversed_Ys, reverse_state).
          * A function of form (B) f(layer, reverse_state) that returns
            a function of form (A). Form (B) is recognized by the callable
            taking exactly two parameters.
          * A :class:`ReverseMappingBase` subclass. It is instantiated once
            per layer and its ``apply`` method is used as form (A).
    :param default_reverse_mapping: A function that reverses layers for
      which no mapping was given by param "reverse_mappings".
    :param head_mapping: Map output tensors to new values before passing
      them into the reverted network. Defaults to the identity.
    :param stop_mapping_at_tensors: Tensors at which to stop the mapping.
      Similar to stop_gradient parameters for gradient computation.
    :param verbose: Print what's going on.
    :param return_all_reversed_tensors: Return all reverted tensors in addition
      to reverted model input tensors.
    :param clip_all_reversed_tensors: Clip each reverted tensor. False or tuple
      with min/max value.
    :param project_bottleneck_tensors: Project bottleneck layers in the
      reverting process into a given value range. False, True or (a, b) for
      projection range.
    :param execution_trace: A precomputed result of
      :func:`trace_model_execution`. If None, the model is traced here.
    :param reapply_on_copied_layers: When the model execution needs to be
      linearized, copy layers before reapplying them. See
      :func:`trace_model_execution`.
    :returns: The reverted tensors of all model inputs that are not listed in
      ``stop_mapping_at_tensors``. If ``return_all_reversed_tensors`` is True,
      a tuple ``(reversed_input_tensors, reversed_tensors)`` where the second
      entry maps every reached forward tensor to its bookkeeping dictionary.
    :raises NotImplementedError: If ``clip_all_reversed_tensors`` is True.
    """

    # Set default values ######################################################
    if stop_mapping_at_tensors is None:
        stop_mapping_at_tensors = []

    if head_mapping is None:

        def head_mapping(X):
            return X

    if not callable(reverse_mappings):
        # Not callable: assume a dict that maps from layer class to mapping.
        reverse_mapping_data = reverse_mappings

        def reverse_mappings(layer):
            try:
                return reverse_mapping_data[type(layer)]
            except KeyError:
                return None

    if clip_all_reversed_tensors is True:
        # Adjacent literals form ONE message string; passing them as two
        # arguments would render the exception message as a tuple.
        raise NotImplementedError(
            "Keyword argument `clip_all_reversed_tensors` "
            "expected to be `False` or tuple with min/max values.")

    def _print(s):
        if verbose is True:
            print(s)

    # Initialize structure that keeps track of reversed tensors:
    # maps a forward tensor to its reverse tensor(s) and node information.
    reversed_tensors: Dict[Tensor, ReverseTensorDict] = {}

    bottleneck_tensors = set()

    def add_reversed_tensors(nid, tensors_list, reversed_tensors_list) -> None:
        """Register reversed tensors for the forward tensors of node `nid`."""

        def add_reversed_tensor(i, X: Tensor, reversed_X: Tensor) -> None:
            # Do not keep tensors that should stop the mapping.
            if X in stop_mapping_at_tensors:  # type: ignore
                return

            if X not in reversed_tensors:  # no duplicate entries for forward tensors
                reversed_tensors[X] = {
                    "id": (nid, i),
                    "tensor": reversed_X,
                    "tensors": None,
                    "final_tensor": None,
                }
            else:
                # A forward tensor consumed by several nodes receives several
                # reversed contributions; collect them for later summation.
                entry = reversed_tensors[X]  # mutates reversed_tensors in place
                if entry["tensor"] is not None:
                    if entry["tensors"] is not None:
                        raise Exception(
                            "Wrong order, tensors already aggregated!")

                    entry["tensors"] = [entry["tensor"], reversed_X]
                    entry["tensor"] = None
                else:
                    raise Exception(
                        "Error during reverse tensor aggeregation.")

        for i, (X, reversed_X) in enumerate(
                zip(tensors_list, reversed_tensors_list)):
            add_reversed_tensor(i, X, reversed_X)

    def get_reversed_tensor(tensor: Tensor) -> Tensor:
        """Return (and cache) the final reversed tensor for `tensor`.

        Multiple contributions are summed with ``keras.layers.Add``; the
        result is optionally projected (bottleneck tensors) and clipped.
        """
        entry: ReverseTensorDict
        entry = reversed_tensors[tensor]

        if entry["final_tensor"] is None:
            if entry["tensor"] is None:
                final_tensor = keras.layers.Add()(entry["tensors"])
            else:
                final_tensor = entry["tensor"]

            if project_bottleneck_tensors is not False:
                if tensor in bottleneck_tensors:
                    project = ilayers.Project(project_bottleneck_tensors)
                    final_tensor = project(final_tensor)

            if isinstance(clip_all_reversed_tensors, tuple):
                clip = ilayers.Clip(*clip_all_reversed_tensors)
                final_tensor = clip(final_tensor)

            entry["final_tensor"] = final_tensor

        return entry["final_tensor"]

    def parameter_count(func) -> int:
        # inspect.signature exists on every Python 3 version able to parse
        # this file (it uses variable annotations), so no getargspec
        # fallback is needed (getargspec was removed in Python 3.11).
        return len(inspect.signature(func).parameters)

    # Reverse the model #######################################################
    _print("Reverse model: {}".format(model))

    # Create a list with nodes in reverse execution order.
    if execution_trace is None:
        execution_trace = trace_model_execution(
            model, reapply_on_copied_layers=reapply_on_copied_layers)
    layers, execution_list, outputs = execution_trace
    len_execution_list = len(execution_list)
    num_input_layers = sum(
        1 for node_layer, _, _ in execution_list
        if isinstance(node_layer, keras.layers.InputLayer))
    len_execution_list_wo_inputs_layers = len_execution_list - num_input_layers
    reverse_execution_list = reversed(execution_list)

    # Initialize the reverse mapping functions.
    initialized_reverse_mappings: Dict[Layer,
                                       Callable]  # TODO: specify Callable
    initialized_reverse_mappings = {}
    for layer in layers:
        # A layer can be shared, i.e., applied several times.
        # Allow to share a ReverseMappingBase for each layer instance
        # in order to reduce the overhead.

        meta_reverse_mapping = reverse_mappings(layer)
        if meta_reverse_mapping is None:
            reverse_mapping = default_reverse_mapping
        elif inspect.isclass(meta_reverse_mapping) and issubclass(
                meta_reverse_mapping, ReverseMappingBase):
            # Mapping is a class
            reverse_mapping_obj = meta_reverse_mapping(
                layer,
                {
                    "model": model,
                    "layer": layer,
                },
            )
            reverse_mapping = reverse_mapping_obj.apply
        elif (callable(meta_reverse_mapping)
              and parameter_count(meta_reverse_mapping) == 2):
            # Function of form (B) that returns the actual mapping.
            reverse_mapping = meta_reverse_mapping(
                layer,
                {
                    "model": model,
                    "layer": layer,
                },
            )
        else:
            # Nothing meta here: already a mapping of form (A).
            reverse_mapping = meta_reverse_mapping

        initialized_reverse_mappings[
            layer] = reverse_mapping  # type: ignore # TODO: add annotations

    if project_bottleneck_tensors:
        bottleneck_tensors.update(
            get_bottleneck_tensors(model.inputs, outputs, execution_list))

    # Initialize the reverse tensor mappings.
    add_reversed_tensors(-1, outputs, [head_mapping(tmp)
                                       for tmp in outputs])  # type: ignore

    # Follow the list and revert the graph.
    for _nid, (layer, Xs, Ys) in enumerate(reverse_execution_list):
        nid = len_execution_list_wo_inputs_layers - _nid - 1

        if isinstance(layer, keras.layers.InputLayer):
            # Special case. Do nothing.
            pass
        elif kchecks.is_network(layer):
            raise Exception("This is not supposed to happen!")
        else:
            Xs, Ys = iutils.to_list(Xs), iutils.to_list(Ys)
            if not all(ys in reversed_tensors for ys in Ys):
                # This node is not part of our computational graph.
                # The (node-)world is bigger than this model.
                # Potentially this node is also not part of the
                # reversed tensor set because it depends on a tensor
                # that is listed in stop_mapping_at_tensors.
                continue
            reversed_Ys = [get_reversed_tensor(ys) for ys in Ys]
            local_stop_mapping_at_tensors = [
                x for x in Xs if x in stop_mapping_at_tensors
            ]

            _print("  [NID: {}] Reverse layer-node {}".format(nid, layer))
            reverse_mapping = initialized_reverse_mappings[layer]
            reversed_Xs = reverse_mapping(
                Xs,
                Ys,
                reversed_Ys,
                {
                    "nid": nid,
                    "model": model,
                    "layer": layer,
                    "stop_mapping_at_tensors": local_stop_mapping_at_tensors,
                },
            )
            reversed_Xs = iutils.to_list(reversed_Xs)
            add_reversed_tensors(nid, Xs, reversed_Xs)

    # Return requested values
    reversed_input_tensors = [
        get_reversed_tensor(tmp) for tmp in model.inputs
        if tmp not in stop_mapping_at_tensors
    ]
    if return_all_reversed_tensors is True:
        return reversed_input_tensors, reversed_tensors
    else:
        return reversed_input_tensors
예제 #2
0
 def collect_layers(container: Model) -> None:
     """Append the layers of `container` to the enclosing `layers` list.

     Nested networks are appended themselves and then descended into
     recursively. Each layer instance is expected to appear only once.
     """
     for child in container.layers:
         # Invariant: a layer instance is never collected twice.
         assert child not in layers
         layers.append(child)
         if not kchecks.is_network(child):
             continue
         collect_layers(child)
예제 #3
0
def trace_model_execution(
    model: Model,
    reapply_on_copied_layers: bool = False
) -> Tuple[List[Layer], List[Tuple[Layer, List[Tensor], List[Tensor]]],
           List[Tensor]]:
    """
    Trace and linearize the execution of a model and its nested containers.

    Returns a triple ``(layers, executed_nodes, outputs)``:
      * ``layers``: all layers returned by :func:`get_model_layers`, replaced
        by their copies when the model had to be re-applied with
        ``reapply_on_copied_layers=True``;
      * ``executed_nodes``: the linearized execution as a list of
        ``(layer, input_tensors, output_tensors)`` triples, keeping only the
        nodes whose outputs lead to ``outputs``;
      * ``outputs``: the output tensors of the execution. These are newly
        created tensors whenever the model had to be re-applied.

    :param model: A Keras model.
    :param reapply_on_copied_layers: If the execution needs to be linearized,
      reapply with copied layers. Might be slow. Prevents changes of the
      original layer's node lists.
    """
    layers: List[Layer] = get_model_layers(model)

    # Does any layer (other than the passed model itself) act as a container?
    nested_network_flags = [
        (candidate is not model) and kchecks.is_network(candidate)
        for candidate in layers
    ]
    contains_container: bool = any(nested_network_flags)

    if contains_container is True:
        # Layers living inside a nested container/model do not record nodes,
        # so their input/output tensors cannot be read from node lists.
        # Instead the graph is recovered by re-applying the model while
        # recording every `call`. The recorded tensors therefore differ
        # from the passed model's tensors. Re-application only touches the
        # inbound/outbound nodes of a copy of the model, so the passed model
        # is not affected.
        recorded: List[Tuple[Layer, List[Tensor], List[Tensor]]] = []

        def make_recording_call(owner, original_call: Callable):
            if hasattr(original_call, "__patched__"):
                raise Exception(
                    "Should not happen as we patch objects, not classes.")

            def recording_call(*args, **kwargs):
                first_argument = args[0]
                produced = original_call(*args, **kwargs)
                recorded.append((owner, first_argument, produced))
                return produced

            recording_call.__patched__ = True  # type: ignore
            return recording_call

        # Remember the original per-instance `call` methods before patching.
        saved_calls: List[Tuple[Layer, Callable]] = [
            (layer, layer.call) for layer in layers
        ]
        try:
            for layer in layers:
                layer.call = make_recording_call(layer, layer.call)

            # Re-apply a copy of the model to trigger the recording.
            model_copy: Model = keras.models.Model(inputs=model.inputs,
                                                   outputs=model.outputs)
            outputs = iutils.to_list(model_copy(model.inputs))
        finally:
            # Always undo the monkey patches.
            for layer, saved_call in saved_calls:
                layer.call = saved_call

        # The recorded tensors carry no keras history because they belong to
        # no node. Re-apply every recorded layer on a flat graph to get it.
        tensor_mapping: Dict[Tensor, Tensor] = {
            inp: inp for inp in model.inputs
        }
        relinked: List[Tuple[Layer, List[Tensor], List[Tensor]]] = []

        layer_mapping: Dict[Layer, Layer]
        if reapply_on_copied_layers is True:
            layer_mapping = {layer: copy_layer(layer) for layer in layers}
        else:
            layer_mapping = {layer: layer for layer in layers}

        for recorded_layer, raw_Xs, raw_Ys in recorded:
            target = layer_mapping[recorded_layer]
            Xs, Ys = iutils.to_list(raw_Xs), iutils.to_list(raw_Ys)

            if isinstance(target, keras.layers.InputLayer):
                # Input layers are kept as they are.
                new_Xs, new_Ys = Xs, Ys
            else:
                new_Xs = [tensor_mapping[x] for x in Xs]
                new_Ys = iutils.to_list(kapply(target, new_Xs))

            # Route later consumers of the recorded Ys to the new tensors.
            tensor_mapping.update(zip(Ys, new_Ys))
            relinked.append((target, new_Xs, new_Ys))

        layers = [layer_mapping[layer] for layer in layers]
        outputs = [tensor_mapping[out] for out in outputs]
        executed_nodes = relinked
    else:
        # No nested containers: read the nodes straight from the model,
        # walking depths in ascending order and then flipping the result.
        nodes_by_ascending_depth: List[Tuple[
            Layer, List[Tensor], List[Tensor]]] = [
                (node.outbound_layer, node.input_tensors, node.output_tensors)
                for depth in sorted(model._nodes_by_depth.keys())
                for node in model._nodes_by_depth[depth]
            ]
        outputs = model.outputs

        executed_nodes = nodes_by_ascending_depth[::-1]

    # Drop nodes that do not contribute to `outputs` (e.g. nodes created by
    # applying a layer outside of this model): walk backwards from the
    # outputs and keep a node only if all of its outputs are consumed.
    consumed = list(outputs)
    kept = []
    for node_layer, Xs, Ys in reversed(executed_nodes):
        if all([y in consumed for y in Ys]):
            consumed.extend(Xs)
            kept.append((node_layer, Xs, Ys))
    kept.reverse()

    return layers, kept, outputs
예제 #4
0
    def _create_computers(self):
        """
        Creates pattern objects and Keras models that are used to collect
        statistics and compute patterns.

        We compute the patterns by first computing statistics within
        the Keras framework, which are then used to compute the patterns.

        This is based on a workaround. We connect the stats computation
        via dummy outputs to a model's output and then iterate over the
        dataset to compute statistics.

        Sets ``self._pattern_instances`` (pattern type -> list of pattern
        instances whose ``has_pattern()`` is not False),
        ``self._n_computer_outputs`` and ``self._computers``. The computers
        are one model with all dummy outputs if
        ``self.compute_layers_in_parallel`` is True, else one model per
        dummy output.

        :raises Exception: If a traced layer is itself a network.
        :raises NotImplementedError: If ``self.gpus`` is greater than 1.
        """
        # Create a broadcasting function that is used to connect
        # the dummy outputs. The broadcaster sums the first model input over
        # every non-batch axis with keepdims=True.
        # Presumably shape (mini_batch_size, 1, ...) — TODO confirm against
        # ilayers.Sum.
        first_input = self.model.inputs[0]
        reduce_axes = list(range(1, len(K.int_shape(first_input))))
        dummy_broadcaster = ilayers.Sum(axis=reduce_axes,
                                        keepdims=True)(first_input)

        def broadcast(x):
            return ilayers.Broadcast()([dummy_broadcaster, x])

        # Collect all tensors that are part of a model's execution.
        layers, execution_list, _ = kgraph.trace_model_execution(self.model)
        model_tensors = set()
        for _, input_tensors, output_tensors in execution_list:
            # Nodes may hold a single tensor instead of a list (the
            # node-list path of trace_model_execution does not normalize),
            # so normalize before collecting; `+` on two tensors would build
            # a tensor op instead of concatenating lists.
            model_tensors.update(iutils.to_list(input_tensors))
            model_tensors.update(iutils.to_list(output_tensors))

        # Create pattern instances and collect the dummy outputs.
        self._pattern_instances = {k: [] for k in self.pattern_types}
        computer_outputs = []
        for layer in layers:
            # This does not work with containers!
            # They should be replaced by trace_model_execution.
            if kchecks.is_network(layer):
                raise Exception("Network in network is not suppored!")
            for pattern_type, clazz in six.iteritems(self.pattern_types):
                pinstance = clazz(
                    self.model,
                    layer,
                    model_tensors=model_tensors,
                    execution_list=execution_list,
                )
                if pinstance.has_pattern() is False:
                    continue
                self._pattern_instances[pattern_type].append(pinstance)
                dummy_output = pinstance.get_stats_from_batch()
                # Broadcast dummy_output to right shape.
                computer_outputs += iutils.to_list(broadcast(dummy_output))

        # Now we create one or more Keras models to train the patterns.
        self._n_computer_outputs = len(computer_outputs)
        if self.compute_layers_in_parallel is True:
            self._computers = [
                keras.models.Model(inputs=self.model.inputs,
                                   outputs=computer_outputs)
            ]
        else:
            self._computers = [
                keras.models.Model(inputs=self.model.inputs,
                                   outputs=computer_output)
                for computer_output in computer_outputs
            ]

        # Distribute computation on more gpus.
        if self.gpus is not None and self.gpus > 1:
            raise NotImplementedError("Not supported yet.")