Code example #1
def tile(router, R):
    # Sum the relevances received from upper layers
    R = lrp_util.sum_relevances(R)

    # Get the current operation, i.e. the tile operation we are currently handling
    current_operation = router.get_current_operation()

    # Get the input to the tiling operation and find the shape of it
    input_to_current_operation = current_operation.inputs[0]
    input_shape = tf.shape(input_to_current_operation)

    # Get the size of the 'predictions_per_sample' dimension from R
    R_shape = tf.shape(R)
    predictions_per_sample = R_shape[1]
    # Reshape to a list so it can be used in the concat below
    predictions_per_sample = tf.reshape(predictions_per_sample, (1, ))

    # Get the tensor that tells how many times the input has been duplicated for each dimension of the input
    copies_per_dimension = current_operation.inputs[1]

    # Get the number of dimensions of the input
    rank_input = tf.size(copies_per_dimension)

    # Transpose R from shape (batch_size, predictions_per_sample, ....) to shape
    # (predictions_per_sample, batch_size, ...) since the predictions_per_sample dimension is left untouched in
    # the processing below
    remaining_axes = tf.range(2, rank_input + 1)
    perm = tf.concat([[1, 0], remaining_axes], 0)
    R = tf.transpose(R, perm)

    # Reshape R to shape (predictions_per_sample, copies_dim_0, input_size_dim_0, ..., copies_dim_(r-1), input_size_dim_(r-1))
    double_rank = rank_input * 2
    zipped_dims = tf.reshape(tf.transpose([copies_per_dimension, input_shape]),
                             (double_rank, ))
    zipped_dims = tf.concat([predictions_per_sample, zipped_dims], 0)
    R = tf.reshape(R, zipped_dims)

    # Transpose R to shape (predictions_per_sample, input_size_dim_0, copies_dim_0, ..., input_size_dim_(r-1), copies_dim_(r-1))
    perm1 = tf.range(2, double_rank + 1, 2)
    perm2 = tf.range(1, double_rank + 1, 2)
    zipped_perm = tf.reshape(tf.transpose([perm1, perm2]), (double_rank, ))
    zipped_perm = tf.concat([[0], zipped_perm], 0)
    R = tf.transpose(R, zipped_perm)

    # Sum R over the 'copies' dimensions, collapsing the relevance from all copies of each input element
    R_new = tf.reduce_sum(R, perm1)

    # Transpose R back from shape (predictions_per_sample, batch_size ....) to shape
    # (batch_size, predictions_per_sample, ...)
    remaining_axes = tf.range(2, rank_input + 1)
    perm = tf.concat([[1, 0], remaining_axes], 0)
    R_new = tf.transpose(R_new, perm)

    # Mark the tiling operation as handled
    router.mark_operation_handled(current_operation)

    # Forward the relevance to the input of the tile operation
    router.forward_relevance_to_operation(R_new, current_operation,
                                          input_to_current_operation.op)
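
The reshape/transpose gymnastics above collapse the relevance received by every copy of a tiled element back onto the original element. Below is a minimal NumPy sketch of the same idea on a toy 1-D input (hypothetical sizes, not code from the project), with batch_size and predictions_per_sample kept as the two leading axes:

import numpy as np

# Toy setup: batch_size = 1, predictions_per_sample = 1, an input of length 4
# that was tiled 3 times along its single non-batch axis.
x = np.arange(4.0)
R_upper = np.tile(x, 3).reshape(1, 1, 12)      # relevance arriving for the tiled tensor

# Group the tiled axis as (copies, input_size) and sum over the copies axis,
# mirroring the reshape/transpose/reduce_sum sequence above.
R_grouped = R_upper.reshape(1, 1, 3, 4)        # (batch, pps, copies, input_size)
R_input = R_grouped.sum(axis=2)                # (1, 1, 4)

print(R_input)                                 # each input element collects the relevance of its 3 copies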
Code example #2
File: sparse_reorder_lrp.py Project: fhvilshoj/lrp
def sparse_reorder(router, R):
    """
    Handles sparse reorder operations by passing the relevance along to the inputs
    :param router: the router object to report changes to
    :param R: the list of tensors containing the relevances from the upper layers
    """
    # Sum the potentially multiple relevances from the upper layers
    R = lrp_util.sum_relevances(R)

    # Get the current operation
    current_operation = router.get_current_operation()

    # Report handled operations
    router.mark_operation_handled(current_operation)

    # Forward the calculated relevance to the inputs of the sparse reorder operation
    for input in current_operation.inputs:
        router.forward_relevance_to_operation(R, current_operation, input.op)
Code example #3
File: forward_lrp.py Project: fhvilshoj/lrp
def forward(router, R):
    """
    Handles all nonlinearities (Relu, Sigmoid, Tanh) by passing the relevance along
    :param router: the router object to report changes to
    :param R: the list of tensors containing the relevances from the upper layers
    """
    # Sum the potentially multiple relevances from the upper layers
    R = lrp_util.sum_relevances(R)

    # Get the current operation
    current_operation = router.get_current_operation()

    # Report handled operations
    router.mark_operation_handled(current_operation)

    # Forward the calculated relevance to the input of the current operation
    router.forward_relevance_to_operation(R, current_operation,
                                          current_operation.inputs[0].op)
Code example #4
File: slicing_lrp.py Project: fhvilshoj/lrp
def slicing(router, R):
    # Sum the relevances
    R = lrp_util.sum_relevances(R)

    # Get the current operation
    current_operation = router.get_current_operation()

    # Get the input to the slicing operation
    input_to_slicing_operation = current_operation.inputs[0]

    # Cut off the batch_size dimension (and the predictions_per_sample dimension, if present) from the tensor
    # shapes, since we never want to pad the batch dimension
    free_dimensions = 2 - (tf.rank(R) - tf.rank(input_to_slicing_operation))
    def _relevant_dims(tensor):
        return tensor[free_dimensions:]

    # Get the shape of the input to the slicing operation
    input_shape = _relevant_dims(tf.shape(input_to_slicing_operation))

    # Find starting point of the slice operation
    starting_point = _relevant_dims(current_operation.inputs[1])

    # Find size of the slice operation
    size_of_slice = _relevant_dims(current_operation.inputs[2])

    # Find the number of zeros to insert after the relevances
    end_zeros = input_shape - (starting_point + size_of_slice)

    # For each axis, insert 'starting_point' zeros before the relevances and
    # 'input_shape' - ('starting_point' + 'size_of_slice') zeros after the relevances
    padding = tf.transpose(tf.stack([starting_point, end_zeros]))

    # Never use any padding for the first two dimensions, since these are always batch_size, predictions_per_sample
    # and should stay constant all the way through the framework
    batch_and_sample_padding = tf.zeros((2, 2), dtype=tf.int32)
    padding = tf.concat([batch_and_sample_padding, padding], axis=0)

    R_new = tf.pad(R, padding)

    # Mark operations as handled
    router.mark_operation_handled(current_operation)

    # Forward the calculated relevances
    router.forward_relevance_to_operation(R_new, current_operation, input_to_slicing_operation.op)
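
The padding computed above re-embeds the relevance of the sliced window at its original position and fills everything outside the slice with zeros. A small NumPy sketch of the idea with toy sizes (hypothetical values, not project code):

import numpy as np

# Toy case: a feature axis of size 6 was sliced with begin = 2 and size = 3.
input_size, begin, size = 6, 2, 3

# Relevance arriving for the slice, shape (batch_size, predictions_per_sample, size)
R = np.ones((1, 1, size))

# Pad with 'begin' zeros before and 'input_size - (begin + size)' zeros after the slice,
# and never pad the batch_size and predictions_per_sample axes.
end_zeros = input_size - (begin + size)
padding = [(0, 0), (0, 0), (begin, end_zeros)]
R_new = np.pad(R, padding)

print(R_new)   # zeros at positions 0-1 and 5, the slice's relevance at positions 2-4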
Code example #5
File: lstm_context_handler.py Project: fhvilshoj/lrp
    def handle_context(self, context):
        path_ = context[CONTEXT_PATH]

        # Get the relevances to move through the LSTM
        R = self.router.get_relevance_for_operation(path_[0])

        # Sum the potentially multiple relevances from the upper layers
        R = lrp_util.sum_relevances(R)

        # Get the path containing all operations in the LSTM
        path = context[CONTEXT_PATH]

        # Get the extra information related to the LSTM context
        extra_context_information = context[EXTRA_CONTEXT_INFORMATION]

        # Get the transpose operation that marks the beginning of the LSTM
        transpose_operation = extra_context_information[
            LSTM_BEGIN_TRANSPOSE_OPERATION]

        # Get the operation that produces the input to the LSTM (i.e. the operation right before
        # the transpose that marks the start of the LSTM)
        input_operation = extra_context_information[LSTM_INPUT_OPERATION]

        # Get the tensor that is the input to the LSTM (i.e. the input to the transpose operation
        # that marks the start of the LSTM)
        LSTM_input = transpose_operation.inputs[0]

        # TODO use configuration to set alpha_beta or epsilon rule
        # Calculate the relevances to distribute to the lower layers
        lstm_config = self.get_configuration(LAYER.LSTM)
        R_new = lstm(lstm_config, path, R, LSTM_input)

        # TODO this call should be done in the LSTM. Not in the context handler.
        # Mark all operations belonging to the LSTM as "handled"
        for op in path:
            self.router.mark_operation_handled(op)

        # TODO this call should be done in the LSTM. Not in the context handler.
        # Forward the relevances to the lower layers
        self.router.forward_relevance_to_operation(R_new, transpose_operation,
                                                   input_operation)
Code example #6
File: sparse_reshape_lrp.py Project: fhvilshoj/lrp
def sparse_reshape(router, R):
    # Sum the potentially multiple relevances from the upper layers
    R = lrp_util.sum_relevances(R)

    # Get the current operation (i.e. the sparse reshape operation we are currently taking care of)
    current_operation = router.get_current_operation()

    # Get the shape of the input to the sparse reshape operation
    input_shape = current_operation.inputs[1]

    # Split the shape of the input into 'batch_size' and everything else
    batch_size, input_shape_without_batch_size = tf.split(
        input_shape, [1, -1], 0)

    # Get the shape of the relevances
    relevances_shape = tf.shape(R)

    # Find the size of the predictions_per_sample dimension
    predictions_per_sample = relevances_shape[1]

    # Cast 'predictions_per_sample' to int64 to be able to concatenate with the dimensions from the sparse
    # tensor which are int64
    predictions_per_sample = tf.cast(predictions_per_sample, tf.int64)

    # Concatenate the dimensions to get the new shape of the relevances
    relevances_new_shape = tf.concat(
        [batch_size, [predictions_per_sample], input_shape_without_batch_size],
        0)

    # Reshape R to the same shape as the input to the reshaping operation
    R_reshaped = tf.sparse_reshape(R, relevances_new_shape)

    # Tell the router that we handled this operation
    router.mark_operation_handled(current_operation)

    # Forward relevance to the operations that produced the inputs to the current operation
    for input in current_operation.inputs:
        router.forward_relevance_to_operation(R_reshaped, current_operation,
                                              input.op)
Code example #7
def shaping(router, R):
    # Sum the potentially multiple relevances from the upper layers
    R = lrp_util.sum_relevances(R)

    # Get the current operation (i.e. the shaping operation we are currently taking care of)
    current_operation = router.get_current_operation()

    # Get the input to the shaping operation
    input_to_current_operation = current_operation.inputs[0]

    # Get the shape of the input to the shaping operation
    input_shape = tf.shape(input_to_current_operation)

    # Split the shape of the input into 'batch_size' and everything else
    batch_size, input_shape_without_batch_size = tf.split(
        input_shape, [1, -1], 0)

    # Get the shape of the relevances
    relevances_shape = tf.shape(R)

    # Find the size of the predictions_per_sample dimension
    predictions_per_sample = relevances_shape[1]

    # Concatenate the dimensions to get the new shape of the relevances
    relevances_new_shape = tf.concat(
        [batch_size, [predictions_per_sample], input_shape_without_batch_size],
        0)

    # Reshape R to the same shape as the input to the reshaping operation that created the tensor
    # but leave the first two dimensions untouched since they are batch_size and predictions_per_sample,
    # which stay constant all the way through the framework
    R_reshaped = tf.reshape(R, relevances_new_shape)

    # Tell the router that we handled this operation
    router.mark_operation_handled(current_operation)

    # Forward relevance to the operation of the input to the current operation
    router.forward_relevance_to_operation(R_reshaped, current_operation,
                                          input_to_current_operation.op)
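
Since a reshape only rearranges values, the relevance is reshaped back to the input's shape while the batch_size and predictions_per_sample axes stay in front. A toy NumPy sketch with hypothetical sizes:

import numpy as np

# Toy case: a reshape/flatten layer turned a (batch_size, 2, 3) input into (batch_size, 6).
batch_size, predictions_per_sample = 4, 1
R = np.random.rand(batch_size, predictions_per_sample, 6)

# Rebuild the input's shape behind the first two axes, as relevances_new_shape does above.
R_reshaped = R.reshape(batch_size, predictions_per_sample, 2, 3)
print(R_reshaped.shape)   # (4, 1, 2, 3)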
Code example #8
    def _lrp_routing(self):
        # Create context handler switch which is used to forward the responsibility of the
        # different types of contexts to the appropriate handlers
        context_switch = ContextHandlerSwitch(self)

        # Handle each context separately by routing the context through the context switch
        for idx, current_context in enumerate(self.contexts):
            self.final_context = len(self.contexts) == idx + 1
            context_switch.handle_context(current_context)

        # Sum the potentially multiple relevances calculated for the input
        final_input_relevances = lrp_util.sum_relevances(
            self.relevances[self._input.op._id])

        # If the starting point relevances had shape (batch_size, classes), remove the extra
        # predictions_per_sample dimension that was added to the starting point relevances
        if not self.starting_point_relevances_had_predictions_per_sample_dimension:
            # Check if the relevances are sparse, in which case we need to use tf's sparse reshape operation
            # to remove the extra dimension
            if isinstance(final_input_relevances, tf.SparseTensor):
                # Get the shape of the final relevances
                final_input_relevances_shape = tf.shape(final_input_relevances)
                # Extract the batch_size dimension
                batch_size = tf.slice(final_input_relevances_shape, [0], [1])
                # Extract all the dimensions after the predictions_per_sample dimension
                sample_dimensions = tf.slice(final_input_relevances_shape, [2],
                                             [-1])
                # Create the new shape of the relevances, i.e. the shape where the predictions_per_sample
                # has been removed
                final_input_relevances_new_shape = tf.concat(
                    [batch_size, sample_dimensions], 0)
                # Remove the predictions_per_sample dimension
                final_input_relevances = tf.sparse_reshape(
                    final_input_relevances, final_input_relevances_new_shape)
            # If the relevances are not sparse, i.e. they are dense, we can just squeeze the extra dimension
            else:
                final_input_relevances = tf.squeeze(final_input_relevances, 1)

        return final_input_relevances
Code example #9
def concatenate(router, R):
    # Sum the potentially multiple relevances from the upper layers
    # Shape: (batch_size, ...)
    R = lrp_util.sum_relevances(R)

    # Get the current concatenate operation
    current_operation = router.get_current_operation()

    # Find axis that the concatenation was over
    axis = current_operation.inputs[-1]

    # Split the relevances in the same order as the inputs were concatenated. Start by initializing empty lists to
    # hold the sizes to split the relevances into and the operations that should receive the relevances
    split_sizes = []
    input_operations = []
    # Run through the inputs to the current operation except the last (the last input is the "axis" input)
    for inp in current_operation.inputs[:-1]:
        # Add the operation to the array
        input_operations.append(inp.op)
        # Find the shape of the operation
        shape = tf.shape(inp)
        # Find and add the size of the input in the "axis" dimension
        split_sizes.append(shape[axis])

    # Adjust the axis to split over, since the lrp router carries one extra dimension for
    # predictions_per_sample
    axis += 1

    # Split the relevances over the "axis" dimension according to the found split sizes
    R_splitted = tf.split(R, split_sizes, axis)

    # Tell the router that we handled the concatenate operation
    router.mark_operation_handled(current_operation)

    # Forward relevance to the operation of the input to the concatenate operation
    for input_index, relevance in enumerate(R_splitted):
        router.forward_relevance_to_operation(relevance, current_operation,
                                              input_operations[input_index])
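
The relevance is cut back into pieces whose sizes match the concatenated inputs, so each input receives exactly the relevance of the region it contributed. A toy NumPy sketch with hypothetical sizes (note that np.split takes cut points, whereas tf.split takes piece sizes):

import numpy as np

# Toy case: two inputs of widths 3 and 5 were concatenated on the last axis,
# so the relevance arrives with shape (batch_size, predictions_per_sample, 8).
R = np.arange(2 * 1 * 8, dtype=float).reshape(2, 1, 8)
split_sizes = [3, 5]

# np.split expects cut indices, so accumulate the sizes first.
R_first, R_second = np.split(R, np.cumsum(split_sizes)[:-1], axis=2)
print(R_first.shape, R_second.shape)   # (2, 1, 3) (2, 1, 5)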
Code example #10
def convolutional(router, R):
    """
    Convolutional lrp
    :param router: the router object to report changes to
    :param R: the list of tensors containing the relevances from the upper layers
    """
    # Sum the potentially multiple relevances from the upper layers
    # Shape of R: (batch_size, predictions_per_sample, out_height, out_width, out_depth)
    R = lrp_util.sum_relevances(R)

    # Start by assuming the activation tensor is the output
    # of a convolution (i.e. not an addition with a bias)
    # Shape of current_tensor and convolution_tensor: (batch_size, out_height, out_width, out_depth)
    current_operation = router.get_current_operation()
    current_tensor = convolution_tensor = current_operation.outputs[0]

    # Remember that there was no bias
    with_bias = False

    bias_tensor = None
    # If the top operation is an addition (i.e. the above assumption
    # does not hold), move through the graph to find the output of the nearest convolution
    if current_operation.type in ['BiasAdd', 'Add']:
        # Shape of convolution_tensor: (batch_size, out_height, out_width, out_depth)
        convolution_tensor = lrp_util.find_first_tensor_from_type(
            current_tensor, 'Conv2D')
        bias_tensor = lrp_util.get_input_bias_from_add(current_tensor)
        with_bias = True

    # Find the inputs to the convolution
    (conv_input, filters) = convolution_tensor.op.inputs

    # Find the padding and strides that were used in the convolution
    padding = convolution_tensor.op.get_attr("padding").decode("UTF-8")
    strides = convolution_tensor.op.get_attr("strides")

    # Extract dimensions of the filters
    filter_sh = filters.get_shape().as_list()
    filter_height = filter_sh[0]
    filter_width = filter_sh[1]

    # Get shape of the input
    input_shape = tf.shape(conv_input)
    batch_size = input_shape[0]
    input_height = input_shape[1]
    input_width = input_shape[2]
    input_channels = input_shape[3]

    # Get the shape of the output of the convolution
    convolution_tensor_shape = tf.shape(convolution_tensor)
    output_height = convolution_tensor_shape[1]
    output_width = convolution_tensor_shape[2]
    output_channels = convolution_tensor_shape[3]

    # Extract every patch of the input (i.e. the portion of the input that a filter covers at a
    # time), to get a tensor of shape
    # (batch_size, out_height, out_width, filter_height*filter_width*input_channels)
    image_patches = tf.extract_image_patches(
        conv_input, [1, filter_height, filter_width, 1], strides, [1, 1, 1, 1],
        padding)

    # Reshape patches to suit linear
    # shape: (batch_size * out_height * out_width, filter_height * filter_width * input_channels)
    linear_input = tf.reshape(image_patches,
                              (batch_size * output_height * output_width,
                               filter_height * filter_width * input_channels))

    # Reshape filters to suit linear
    linear_filters = tf.reshape(filters, (-1, output_channels))

    # Transpose relevances to suit linear
    # Shape goes from (batch_size, predictions_per_sample, out_height, out_width, out_channels)
    # to: (batch_size, out_height, out_width, predictions_per_sample, out_channels)
    R_shape = tf.shape(R)

    # Find the number of predictions per sample from R
    predictions_per_sample = R_shape[1]

    # Make transpose order (0, 2, .. , 1, last_dim)
    # This is necessary because for conv1d the output might have been expanded which
    # makes the output size partially unknown
    transpose_order = tf.concat(
        [[0],
         tf.range(2,
                  tf.size(R_shape) - 1), [1], [tf.size(R_shape) - 1]], 0)

    # Do the actual transpose
    linear_R = tf.transpose(R, transpose_order)

    # Reshape linear_R to have three dimensions
    linear_R = tf.reshape(linear_R, (batch_size * output_height * output_width,
                                     predictions_per_sample, output_channels))

    # Fetch configuration for linear_lrp
    config = router.get_configuration(LAYER.CONVOLUTIONAL)

    if config.type == RULE.ZB:
        # Reshape to convolution shape
        low = tf.reshape(config.low,
                         [1, input_height, input_width, input_channels])
        high = tf.reshape(config.high,
                          [1, input_height, input_width, input_channels])

        # Tile to batch size
        # Shapes: (batch_size, input_height, input_width, input_channels)
        low = tf.tile(low, [batch_size, 1, 1, 1])
        high = tf.tile(high, [batch_size, 1, 1, 1])

        # Extract image patches
        # Shapes: (batch_size, output_height, output_width, filter_height * filter_width * input_channels)
        low = tf.extract_image_patches(low,
                                       [1, filter_height, filter_width, 1],
                                       strides, [1, 1, 1, 1], padding)
        high = tf.extract_image_patches(high,
                                        [1, filter_height, filter_width, 1],
                                        strides, [1, 1, 1, 1], padding)

        # Reshape image patches for linear layer
        config.low = tf.reshape(
            low, (batch_size * output_height * output_width,
                  filter_height * filter_width * input_channels, 1))
        config.high = tf.reshape(
            high, (batch_size * output_height * output_width,
                   filter_height * filter_width * input_channels, 1))

    # Pass the responsibility to linear_lrp
    # Shape of linear_R_new:
    # (batch_size * out_height * out_width, predictions_per_sample, filter_height * filter_width * input_channels)
    linear_R_new = linear_with_config(linear_R,
                                      linear_input,
                                      linear_filters,
                                      config,
                                      bias=bias_tensor)

    # Shape back to be able to restitch
    linear_R_new = tf.reshape(
        linear_R_new,
        (batch_size, output_height, output_width, predictions_per_sample,
         filter_height * filter_width * input_channels))

    # Transpose back to be able to restitch
    # New shape:
    # (batch_size, predictions_per_sample, out_height, out_width, filter_height * filter_width * input_channels)
    linear_R_new = tf.transpose(linear_R_new, [0, 3, 1, 2, 4])

    # Gather batch_size and predictions_per_sample
    # New shape:
    # (batch_size * predictions_per_sample, out_height, out_width, filter_height * filter_width * input_channels)
    linear_R_new = tf.reshape(
        linear_R_new,
        (batch_size * predictions_per_sample, output_height, output_width,
         filter_height * filter_width * input_channels))

    # Restitch relevances to the input size
    R_new = lrp_util.patches_to_images(
        linear_R_new, batch_size * predictions_per_sample, input_height,
        input_width, input_channels, output_height, output_width,
        filter_height, filter_width, strides[1], strides[2], padding)

    # Reshape the calculated relevances from
    # (batch_size * predictions_per_sample, input_height, input_width, input_channels) to new shape:
    # (batch_size, predictions_per_sample, input_height, input_width, input_channels)
    R_new = tf.reshape(R_new, (batch_size, predictions_per_sample,
                               input_height, input_width, input_channels))

    # Report handled operations
    router.mark_operation_handled(current_tensor.op)
    router.mark_operation_handled(convolution_tensor.op)

    # In case of 1D convolution we need to skip the squeeze operation in
    # the path towards the input
    if with_bias and current_tensor.op.inputs[0].op.type == 'Squeeze':
        router.mark_operation_handled(current_tensor.op.inputs[0].op)

    # Forward the calculated relevance to the input of the convolution
    router.forward_relevance_to_operation(R_new, convolution_tensor.op,
                                          conv_input.op)
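
The key trick above is that a convolution applied to the extracted image patches is just a dense layer applied once per output position, which is why the convolution's relevance can be computed with the linear LRP rule and stitched back with patches_to_images. A toy NumPy check of that equivalence (hypothetical sizes, not project code):

import numpy as np

# Toy case: a VALID 2x2 convolution over a 3x3, 2-channel input with 4 output channels.
x = np.random.rand(3, 3, 2)            # (height, width, in_channels), batch omitted
w = np.random.rand(2, 2, 2, 4)         # (filter_height, filter_width, in_channels, out_channels)

# Flatten every 2x2 patch, as extract_image_patches does, giving one row per output position.
patches = np.stack([x[i:i + 2, j:j + 2, :].ravel()
                    for i in range(2) for j in range(2)])    # (4 positions, 8)
linear_filters = w.reshape(-1, 4)                            # (8, 4)

# A per-patch matrix multiply reproduces the convolution output at each position,
# which is why the code above can hand the problem to linear_with_config.
conv_as_linear = patches @ linear_filters                    # (4 positions, 4 out_channels)
print(conv_as_linear.shape)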
Code example #11
def max_pooling(router, R):
    """
    Max pooling lrp
    :param router: the router object to report changes to
    :param R: the list of tensors containing the relevances from the upper layers
    """
    # Sum the potentially multiple relevances from the upper layers
    # Shape of R: (batch_size, predictions_per_sample, out_height, out_width, out_depth)
    R = lrp_util.sum_relevances(R)

    # Get current operation from the router
    current_operation = router.get_current_operation()

    # Get the output from the max pool operation
    # Shape of current_tensor: (batch_size, out_height, out_width, out_depth)
    current_tensor = current_operation.outputs[0]

    # Find the input to the max pooling
    # Shape of current_tensor: (batch_size, in_height, in_width, in_depth)
    max_pool_input = current_operation.inputs[0]

    # Find the padding and strides that were used in the max pooling
    padding = current_operation.get_attr("padding").decode("UTF-8")
    strides = current_operation.get_attr("strides")
    kernel_size = current_operation.get_attr("ksize")

    # Get shape of the input
    max_pooling_input_shape = tf.shape(max_pool_input)
    batch_size = max_pooling_input_shape[0]
    input_height = max_pooling_input_shape[1]
    input_width = max_pooling_input_shape[2]
    input_channels = max_pooling_input_shape[3]

    # (batch_size, input_height, input_width, input_channels) = max_pool_input.get_shape().as_list()

    # Get the shape of the output of the max pool operation
    current_tensor_shape = tf.shape(current_tensor)
    output_height = current_tensor_shape[1]
    output_width = current_tensor_shape[2]
    output_channels = current_tensor_shape[3]

    # Extract information about R for later reshapes
    R_shape = tf.shape(R)
    predictions_per_sample = R_shape[1]

    batch_size_times_predictions_per_sample = batch_size * predictions_per_sample

    # (_, output_height, output_width, output_channels) = current_tensor.get_shape().as_list()

    # Extract every patch of the input (i.e. portion of the input that the kernel looks at a time)
    # Shape of image_patches: (batch, out_height, out_width, kernel_height*kernel_width*input_channels)
    image_patches = tf.extract_image_patches(max_pool_input, kernel_size,
                                             strides, [1, 1, 1, 1], padding)

    # Reshape image patches to "small images" instead of lists
    # Shape of image_patches after reshape:
    # (batch_size, out_height, out_width, kernel_height, kernel_width, input_channels)
    image_patches_reshaped = tf.reshape(image_patches, [
        batch_size, output_height, output_width, kernel_size[1],
        kernel_size[2], input_channels
    ])

    def _winner_takes_all():
        image_patches_transposed = tf.transpose(image_patches_reshaped,
                                                [0, 1, 2, 5, 3, 4])
        image_patches_rt = tf.reshape(
            image_patches_transposed,
            (batch_size, output_height, output_width, input_channels, -1))
        image_patches_argmax = tf.argmax(image_patches_rt, axis=-1)
        # Shape (batch_size, out_height, out_width, input_channels, kernel_height * kernel_width)
        image_patches_one_hot = tf.one_hot(image_patches_argmax,
                                           kernel_size[1] * kernel_size[2])
        one_hot_transposed = tf.transpose(image_patches_one_hot,
                                          [0, 1, 2, 4, 3])
        one_hot_reshaped = tf.reshape(
            one_hot_transposed,
            (batch_size, output_height, output_width, kernel_size[1],
             kernel_size[2], input_channels))
        return one_hot_reshaped

    def _winners_take_all():
        # Find the largest elements in each patch and set all other entries to zero (to find z_ijk+'s)
        # Shape of max_elems: (batch_size, out_height, out_width, 1, 1, input_channels)
        max_elems = tf.reshape(
            current_tensor,
            (batch_size, output_height, output_width, 1, 1, input_channels))

        # Select maximum in each patch and set all others to zero
        # Shape of zs: (batch_size, out_height, out_width, kernel_height, kernel_width, input_channels)
        return tf.where(tf.equal(image_patches_reshaped, max_elems),
                        tf.ones_like(image_patches_reshaped),
                        tf.zeros_like(image_patches_reshaped))

    def _distribute_relevances():
        # Do nothing. This will distribute the relevance according to preactivations
        return image_patches_reshaped

    config = router.get_configuration(LAYER.MAX_POOLING)
    if config.type == RULE.WINNERS_TAKE_ALL:
        zs = _winners_take_all()
    elif config.type == RULE.WINNER_TAKES_ALL:
        zs = _winner_takes_all()
    else:
        zs = _distribute_relevances()

    # Count how many zijs had the maximum value for each patch
    denominator = tf.reduce_sum(zs, axis=[3, 4], keep_dims=True)
    denominator = tf.where(tf.equal(denominator, 0), tf.ones_like(denominator),
                           denominator)

    # Find the contribution of each feature in the input to the activations,
    # i.e. the ratio between the z_ijk's and the z_jk's
    # Shape of fractions: (batch_size, out_height, out_width, kernel_height, kernel_width, input_channels)
    fractions = zs / denominator

    # Add the predictions_per_sample dimension to be able to broadcast fractions over the different
    # predictions for the same sample
    # Shape after expand_dims:
    # (batch_size, predictions_per_sample=1, out_height, out_width, kernel_height, kernel_width, input_channels)
    fractions = tf.expand_dims(fractions, 1)

    # Put the relevance for each patch in the dimension that corresponds to the "input_channel" dimension
    # of the fractions
    # Shape of R after reshape: (batch_size, predictions_per_sample, out_height, out_width, 1, 1, out_channels)
    R_distributed = tf.reshape(R, [
        batch_size, predictions_per_sample, output_height, output_width, 1, 1,
        output_channels
    ])

    # Distribute the relevance onto the fractions
    # Shape of new relevances: (batch_size, predictions_per_sample, out_height, out_width, kernel_height, kernel_width, input_channels)
    relevances = fractions * R_distributed

    # Put the batch size and predictions_per_sample on the same dimension to be able to use the patches_to_images tool.
    # Also rearrange patches back to lists from the "small images".
    # Shape of relevances after reshape:
    # (batch_size * predictions_per_sample, out_height, out_width, kernel_height * kernel_width * input_channels)
    relevances = tf.reshape(
        relevances,
        (batch_size_times_predictions_per_sample, output_height, output_width,
         kernel_size[1] * kernel_size[2] * input_channels))

    # Reconstruct the shape of the input, thereby summing the relevances for each individual feature
    R_new = lrp_util.patches_to_images(
        relevances, batch_size_times_predictions_per_sample, input_height,
        input_width, input_channels, output_height, output_width,
        kernel_size[1], kernel_size[2], strides[1], strides[2], padding)

    # Reshape the relevances back to having batch_size and predictions_per_sample as the first two dimensions
    # (batch_size, predictions_per_sample, input_height, input_width, input_channels) rather than
    # (batch_size * predictions_per_sample, input_height, input_width, input_channels)
    R_new = tf.reshape(R_new, (batch_size, predictions_per_sample,
                               input_height, input_width, input_channels))

    # Report handled operations
    router.mark_operation_handled(current_operation)

    # Forward the calculated relevance to the input of the convolution
    router.forward_relevance_to_operation(R_new, current_operation,
                                          max_pool_input.op)
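
Under the winners-take-all rule above, each pooled output's relevance is handed entirely to the entries that attain the maximum in their patch (split evenly on ties). A toy NumPy sketch for a single 2x2 patch (hypothetical values, not project code):

import numpy as np

# Toy case: one 2x2 pooling patch and one unit of relevance for its single output.
patch = np.array([[1.0, 3.0],
                  [2.0, 0.5]])
R_out = 1.0

# z is 1 where the entry equals the patch maximum and 0 elsewhere, as in _winners_take_all.
zs = (patch == patch.max()).astype(float)

# Normalize per patch and distribute the relevance over the fractions.
fractions = zs / zs.sum()
R_patch = fractions * R_out
print(R_patch)   # all relevance lands on the entry holding 3.0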