Example #1
def conv_2d(inputs, ksize, nchannel, stride, padding, data_format="NHWC"):
    """conv 2d, NHWC"""

    if data_format == "NHWC":
        fanin = get_tensor_shape(inputs)[-1]
        strides = [1, stride, stride, 1]
    else:
        fanin = get_tensor_shape(inputs)[1]
        strides = [1, 1, stride, stride]
    W, b = get_W_b_conv2d(ksize=ksize, fanin=fanin, fanout=nchannel)
    conv = tf.nn.conv2d(
        inputs, W, strides=strides,
        padding=padding, data_format=data_format)

    return tf.nn.bias_add(conv, b, data_format=data_format)
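
A minimal usage sketch (TF1 style); get_tensor_shape and get_W_b_conv2d are not shown in this example, so the stand-ins below are assumptions rather than the module's real helpers:

import tensorflow as tf

def get_tensor_shape(t):
    # Stand-in (assumption): static shape as a list of ints.
    return t.get_shape().as_list()

def get_W_b_conv2d(ksize, fanin, fanout):
    # Stand-in (assumption): HWIO filter plus bias.
    W = tf.get_variable("W", [ksize, ksize, fanin, fanout],
                        initializer=tf.glorot_uniform_initializer())
    b = tf.get_variable("b", [fanout], initializer=tf.zeros_initializer())
    return W, b

x = tf.placeholder(tf.float32, [None, 32, 32, 3])  # NHWC
y = conv_2d(x, ksize=3, nchannel=16, stride=2, padding="SAME")
print(y.get_shape().as_list())  # [None, 16, 16, 16]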
Example #2
def train_step(tokens, length, training, dropout):
    inputs = tokens[:, :-1]
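    # encoder/decoder, optimizer, use_2d, and encoded_length are not
    # defined in this snippet; presumably they are captured from an
    # enclosing scope.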
    encoded = encoder(inputs=inputs,
                      length=length,
                      dropout=dropout,
                      attention_dropout=dropout,
                      use_2d=use_2d)
    logits = decoder(inputs=inputs,
                     encoded=encoded,
                     dropout=dropout,
                     attention_dropout=dropout,
                     use_2d=use_2d,
                     encoded_length=encoded_length)
    ln = utils.get_tensor_shape(tokens)[1]
    mask = tf.sequence_mask(length + 1, ln)
    loss = utils.loss(tokens, logits, mask, use_2d=use_2d)

    def true_fn():
        weights = get_weights()
        grads = tf.gradients(loss, weights)
        opt = optimizer.apply_gradients(zip(grads, weights))
        with tf.control_dependencies([opt]):
            opt = tf.zeros([], tf.bool)
        return opt

    def false_fn():
        opt = tf.zeros([], tf.bool)
        return opt

    opt = tf.cond(training, true_fn, false_fn)
    return loss, opt
Example #3
def train(tokens, length, training, dropout):
    training = tf.squeeze(training)
    dropout = tf.squeeze(dropout)
    tokens = tf.squeeze(tokens, 0)
    length = tf.squeeze(length, 0)
    steps = utils.get_tensor_shape(tokens)[0]

    def body(step, loss, opt):
        _tokens, _length = tf.gather(tokens, step), tf.gather(length, step)
        with tf.control_dependencies([opt]):
            _loss, opt = train_step(_tokens, _length, training, dropout)
        loss = loss + _loss
        step = step + 1
        return step, loss, opt

    def cond(step, loss, opt):
        return step < steps

    step = tf.zeros([], tf.int32)
    loss = tf.zeros([])
    opt = tf.zeros([], tf.bool)
    var_loops = [step, loss, opt]
    step, loss, opt = tf.contrib.tpu.while_loop(cond, body, var_loops)
    loss = loss / tf.cast(steps, dtype=loss.dtype)
    return loss, opt
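
The same accumulate-then-step pattern works with the stock tf.while_loop; a small self-contained sketch (the code above uses the TPU variant):

import tensorflow as tf

losses = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0])
steps = 5

def body(step, total):
    # Thread the running sum through the loop variables.
    return step + 1, total + tf.gather(losses, step)

def cond(step, total):
    return step < steps

_, total = tf.while_loop(cond, body, [tf.constant(0), tf.constant(0.0)])
with tf.Session() as sess:
    print(sess.run(total) / steps)  # 3.0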
Example #4
    def _build_st(self, module, xyz, cs, names, out_size,
                  reduce_ratio=None, do_rotate=False, do_reverse=False):
        """Subroutine for building spatial transformer"""

        for name in names:
            with tf.variable_scope(module):
                cur_inputs = self.inputs["patch"][name]
                # Get xy and cs
                if xyz is not None:
                    _xyz = xyz[name]["xyz"]
                else:
                    batch_size = tf.shape(cur_inputs)[0]
                    _xyz = tf.zeros((batch_size, 3))
                if cs is not None:
                    _cs = cs[name]["cs"]
                else:
                    _cs = None
                # transform coordinates
                # if do_rotate:
                #     _xyz[:2] = self.transform_xyz(_xyz,
                #                                   _cs,
                #                                   config.batch_size,
                #                                   reverse=do_reverse)
                input_size = get_tensor_shape(cur_inputs)[2]
                if reduce_ratio is None:
                    reduce_ratio = float(out_size) / float(input_size)
                # apply the spatial transformer
                self.outputs[module][name] = transformer(
                    U=cur_inputs,
                    # Get the output from the keypoint layer
                    theta=make_theta(xyz=_xyz, cs=_cs, rr=reduce_ratio),
                    out_size=(out_size, out_size),
                )
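
make_theta is not shown in this snippet; purely as an illustration, a hypothetical stand-in that builds a translation-plus-scale 2x3 affine for the transformer could look like this (an assumption, not the real helper):

import tensorflow as tf

def make_theta_sketch(xyz, rr):
    # Hypothetical: xyz[:, 0] / xyz[:, 1] as translation, rr as a scalar
    # isotropic scale; flattened row-major 2x3 affine
    # [[rr, 0, x], [0, rr, y]].
    x, y = xyz[:, 0], xyz[:, 1]
    zeros = tf.zeros_like(x)
    scale = tf.fill(tf.shape(x), tf.constant(rr, tf.float32))
    return tf.stack([scale, zeros, x, zeros, scale, y], axis=1)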
Example #5
    def seq_encoder(self, input, u, cell_size, length, scope):
        input_shape = utils.get_tensor_shape(input)
        input = tf.reshape(input,
                           [np.prod(input_shape[:-2])] + input_shape[-2:])
        length = tf.reshape(length, [-1])
        rep = self.gru_encoder(input, cell_size, length, scope)
        rep = tf.reshape(rep, input_shape[:-1] + [cell_size * 2])
        return self.atten_encoder(u, rep)
Example #6
def ghh(inputs, num_in_sum, num_in_max, data_format="NHWC"):
    """GHH layer

    LATER: Make it more efficient

    """

    # Assert NHWC
    assert data_format == "NHWC"

    # Check validity
    inshp = get_tensor_shape(inputs)
    num_channels = inshp[-1]
    pool_axis = len(inshp) - 1
    assert (num_channels % (num_in_sum * num_in_max)) == 0

    # Abuse cur_in
    cur_in = inputs

    # # Okay the maxpooling and avgpooling functions do not like weird
    # # pooling. Just reshape to avoid this issue.
    # inshp = get_tensor_shape(inputs)
    # numout = int(inshp[1] / (num_in_sum * num_in_max))
    # cur_in = tf.reshape(cur_in, [
    #     inshp[0], numout, num_in_sum, num_in_max, inshp[2], inshp[3]
    # ])

    # Reshaping does not work for undecided input sizes; use split instead.
    cur_ins_to_max = tf.split(cur_in,
                              num_channels // num_in_max,
                              axis=pool_axis)

    # Do max and concat them back
    cur_in = tf.concat([
        tf.reduce_max(cur_ins, axis=pool_axis, keep_dims=True)
        for cur_ins in cur_ins_to_max
    ],
                       axis=pool_axis)

    # Create delta
    delta = (1.0 - 2.0 * (np.arange(num_in_sum) % 2)).astype("float32")
    delta = tf.reshape(delta, [1] * (len(inshp) - 1) + [num_in_sum])

    # Again, split into multiple pieces
    cur_ins_to_sum = tf.split(cur_in,
                              num_channels // (num_in_max * num_in_sum),
                              axis=pool_axis)

    # Do delta multiplication, sum, and concat them back
    cur_in = tf.concat([
        tf.reduce_sum(cur_ins * delta, axis=pool_axis, keep_dims=True)
        for cur_ins in cur_ins_to_sum
    ],
                       axis=pool_axis)

    return cur_in
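
A numpy sketch of the same computation on a toy NHWC tensor, to make the grouping explicit (the static reshape works here because the toy sizes are known):

import numpy as np

num_in_sum, num_in_max = 4, 4
x = np.random.randn(1, 2, 2, num_in_sum * num_in_max)  # one output channel

# Max over groups of num_in_max consecutive channels...
maxed = x.reshape(1, 2, 2, num_in_sum, num_in_max).max(axis=-1)
# ...then an alternating +1/-1 sum over the num_in_sum groups.
delta = 1.0 - 2.0 * (np.arange(num_in_sum) % 2)
out = (maxed * delta).sum(axis=-1, keepdims=True)
print(out.shape)  # (1, 2, 2, 1)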
Example #7
def divide_batch(x, n):
    shape = get_tensor_shape(x)
    batch_size = shape[0]
    if isinstance(shape, list):
        new_shape = [n, batch_size // n] + shape[1:]
    else:
        new_batch_shape = tf.convert_to_tensor([n, batch_size // n])
        new_shape = tf.concat([new_batch_shape, shape[1:]], 0)
    result = tf.reshape(x, new_shape)
    return result
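
A quick check with a stand-in get_tensor_shape (an assumption; here it returns the fully known static shape):

import tensorflow as tf

def get_tensor_shape(t):
    # Stand-in (assumption): static shape as a list.
    return t.get_shape().as_list()

x = tf.zeros([8, 32, 32, 3])
print(divide_batch(x, 2).get_shape().as_list())  # [2, 4, 32, 32, 3]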
Example #8
    def atten_encoder(self, Q, K):
        # Q: ... * seq_len_q * F
        # K = V: ... * seq_len_k * F
        K_shape = utils.get_tensor_shape(K)
        K = tf.layers.dense(K, K_shape[-1], activation=tf.nn.tanh)
        # ======================================================
        # Q = tf.transpose(Q, [-1, -2])
        # scores = tf.map_fn(lambda x: x @ Q, K, dtype=tf.float32)
        # ------------- another implementation -----------------
        scores = tf.reduce_sum(K * Q, -1, keepdims=True)
        # ======================================================
        scores = tf.nn.softmax(scores, -2)
        return tf.reduce_sum(scores * K, -2), tf.squeeze(scores)
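
A self-contained shape walk-through of the scoring branch (a sketch; B, T, F are arbitrary toy sizes):

import tensorflow as tf

B, T, F = 2, 5, 8
Q = tf.random_normal([B, 1, F])  # query, broadcast over time
K = tf.random_normal([B, T, F])

scores = tf.nn.softmax(tf.reduce_sum(K * Q, -1, keepdims=True), -2)  # [B, T, 1]
pooled = tf.reduce_sum(scores * K, -2)                               # [B, F]
print(pooled.get_shape().as_list())  # [2, 8]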
Example #9
def fc(inputs, fanout):
    """fully connected, NC """

    inshp = get_tensor_shape(inputs)
    fanin = np.prod(inshp[1:])

    # Flatten input if needed
    if len(inshp) > 2:
        inputs = tf.reshape(inputs, (inshp[0], fanin))

    W, b = get_W_b_fc(fanin=fanin, fanout=fanout)
    mul = tf.matmul(inputs, W)

    return tf.nn.bias_add(mul, b)
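
A usage sketch with hypothetical stand-ins for the two helpers (assumptions, not the module's real implementations):

import tensorflow as tf

def get_tensor_shape(t):
    # Stand-in (assumption): static shape as a list.
    return t.get_shape().as_list()

def get_W_b_fc(fanin, fanout):
    # Stand-in (assumption): dense weight plus bias.
    W = tf.get_variable("W_fc", [fanin, fanout],
                        initializer=tf.glorot_uniform_initializer())
    b = tf.get_variable("b_fc", [fanout], initializer=tf.zeros_initializer())
    return W, b

x = tf.zeros([4, 7, 7, 32])     # NHWC feature map
y = fc(x, fanout=10)            # flattened to [4, 1568] before the matmul
print(y.get_shape().as_list())  # [4, 10]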
Example #10
def conv_2d_trans(inputs, ksize, nchannel, stride, padding, data_format="NHWC"):
    """conv 2d, transposed, NHWC"""

    assert padding == "VALID"

    inshp = tf.shape(inputs)
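    # Note: with padding="VALID", the natural transposed-conv output size
    # is (in - 1) * stride + ksize per spatial dim; the in * stride used
    # below agrees with it only when ksize == stride, and the
    # commented-out max(ksize - stride, 0) term covers ksize > stride.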
    if data_format == "NHWC":
        fanin = get_tensor_shape(inputs)[-1]
        strides = [1, stride, stride, 1]
        output_shape = tf.stack(
            [inshp[0],
             inshp[1] * int(stride),  # + max(ksize - stride, 0),
             inshp[2] * int(stride),  # + max(ksize - stride, 0),
             nchannel])
    else:
        fanin = get_tensor_shape(inputs)[1]
        strides = [1, 1, stride, stride]
        output_shape = tf.stack(
            [inshp[0],
             nchannel,
             inshp[2] * int(stride),  # + max(ksize - stride, 0),
             inshp[3] * int(stride),  # + max(ksize - stride, 0)
             ])

    with tf.variable_scope("W"):
        W, _ = get_W_b_conv2d(ksize=ksize, fanin=nchannel, fanout=fanin)
    with tf.variable_scope("b"):
        _, b = get_W_b_conv2d(ksize=ksize, fanin=fanin, fanout=nchannel)

    deconv2dres = tf.nn.conv2d_transpose(
        inputs, W, output_shape, strides=strides, padding=padding,
        data_format=data_format)

    deconv2dres = tf.reshape(deconv2dres, output_shape)

    return tf.nn.bias_add(deconv2dres, b, data_format=data_format)
Example #11
    def call(self, inputs, training=None, dropout=0.1):
        if training is None:
            training = True
        training = tf.convert_to_tensor(training)
        x = tf.expand_dims(inputs, -1)
        x = self.conv1(x)
        x = self.max_pool(x)
        x = tf.keras.activations.relu(x)
        x = self.conv2(x)
        x = self.max_pool(x)
        x = tf.keras.activations.relu(x)
        shape = get_tensor_shape(x)
        new_shape = (shape[0], shape[1] * shape[2] * shape[3])
        x = tf.reshape(x, new_shape)
        x = tf.cond(training, lambda: tf.nn.dropout(x, dropout), lambda: x)
        x = self.dense1(x)
        x = tf.cond(training, lambda: tf.nn.dropout(x, dropout), lambda: x)
        x = self.dense2(x)
        return x
Example #12
def process(inputs, bypass, name, skip, config, is_training):
    """WRITEME.

    LATER: Clean up

    inputs: input to the network
    bypass: gt to by used when trying to bypass
    name: name of the siamese branch
    skip: whether to apply the bypass information

    """

    # let's look at the inputs that get fed into this layer except when we are
    # looking at the whole image
    if name != "img":
        image_summary_nhwc(name + "-input", inputs)

    if skip:
        return bypass_kp(bypass)

    # we always expect a dictionary as return value to be more explicit
    res = {}

    # now abuse cur_in so that we can simply copy paste
    cur_in = inputs

    # let's apply batch normalization on the input - we did not normalize
    # the input range!
    # with tf.variable_scope("input-bn"):
    #     if config.use_input_batch_norm:
    #         cur_in = batch_norm(cur_in, training=is_training)

    with tf.variable_scope("conv-ghh-1"):
        nu = 1
        ns = 4
        nm = 4
        cur_in = conv_2d(cur_in, config.kp_filter_size, nu * ns * nm, 1,
                         "VALID")
        # batch norm on the output of convolutions!
        # if config.use_batch_norm:
        #     cur_in = batch_norm(cur_in, training=is_training)
        cur_in = ghh(cur_in, ns, nm)

    res["scoremap-uncut"] = cur_in

    # ---------------------------------------------------------------------
    # Check how much we need to cut
    kp_input_size = config.kp_input_size
    patch_size = get_patch_size_no_aug(config)
    desc_input_size = config.desc_input_size
    rf = float(kp_input_size) / float(patch_size)

    input_shape = get_tensor_shape(inputs)
    uncut_shape = get_tensor_shape(cur_in)
    req_boundary = np.ceil(rf * np.sqrt(2) * desc_input_size / 2.0).astype(int)
    cur_boundary = (input_shape[2] - uncut_shape[2]) // 2
    crop_size = req_boundary - cur_boundary

    # Stop building the network outputs if we are building for the full image
    if name == "img":
        return res

    # # Debug messages
    # resized_shape = get_tensor_shape(inputs)
    # print(' -- kp_info: output score map shape {}'.format(uncut_shape))
    # print(' -- kp_info: input size after resizing {}'.format(resized_shape[2]))
    # print(' -- kp_info: output score map size {}'.format(uncut_shape[2]))
    # print(' -- kp info: required boundary {}'.format(req_boundary))
    # print(' -- kp info: current boundary {}'.format(cur_boundary))
    # print(' -- kp_info: additional crop size {}'.format(crop_size))
    # print(' -- kp_info: final cropped score map size {}'.format(
    #     uncut_shape[2] - 2 * crop_size))
    # print(' -- kp_info: movement ratio will be {}'.format((
    #     float(uncut_shape[2] - 2.0 * crop_size) /
    #     float(kp_input_size - 1))))

    # Crop center
    cur_in = cur_in[:, crop_size:-crop_size, crop_size:-crop_size, :]
    res["scoremap"] = cur_in

    # ---------------------------------------------------------------------
    # Mapping layer to x,y,z
    com_strength = config.kp_com_strength
    # eps = 1e-10
    scoremap_shape = get_tensor_shape(cur_in)

    od = len(scoremap_shape)
    # CoM to get the coordinates
    pos_array_x = tf.range(scoremap_shape[2], dtype=tf.float32)
    pos_array_y = tf.range(scoremap_shape[1], dtype=tf.float32)

    out = cur_in
    max_out = tf.reduce_max(out, axis=list(range(1, od)), keep_dims=True)
    o = tf.exp(com_strength * (out - max_out))  # + eps
    sum_o = tf.reduce_sum(o, axis=list(range(1, od)), keep_dims=True)
    x = tf.reduce_sum(o * tf.reshape(pos_array_x, [1, 1, -1, 1]),
                      axis=list(range(1, od)),
                      keep_dims=True) / sum_o
    y = tf.reduce_sum(o * tf.reshape(pos_array_y, [1, -1, 1, 1]),
                      axis=list(range(1, od)),
                      keep_dims=True) / sum_o

    # Remove the unnecessary dimensions (i.e. flatten them)
    x = tf.reshape(x, (-1, ))
    y = tf.reshape(y, (-1, ))

    # --------------
    # Turn x, and y into range -1 to 1, where the patch size is
    # mapped to -1 and 1
    orig_patch_width = (scoremap_shape[2] +
                        np.cast["float32"](req_boundary * 2.0))
    orig_patch_height = (scoremap_shape[1] +
                         np.cast["float32"](req_boundary * 2.0))

    x = ((x + np.cast["float32"](req_boundary)) / np.cast["float32"](
        (orig_patch_width - 1.0) * 0.5) - np.cast["float32"](1.0))
    y = ((y + np.cast["float32"](req_boundary)) / np.cast["float32"](
        (orig_patch_height - 1.0) * 0.5) - np.cast["float32"](1.0))

    # --------------
    # No movement in z direction
    z = tf.zeros_like(x)

    res["xyz"] = tf.stack([x, y, z], axis=1)

    # ---------------------------------------------------------------------
    # Mapping layer to score
    res["score"] = softmax(
        res["scoremap"],
        axis=list(range(1, od)),
        softmax_strength=config.kp_scoremap_softmax_strength)

    return res
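
The coordinate mapping above is a softmax-weighted center of mass (a soft-argmax). A standalone 1D sketch of the same idea:

import tensorflow as tf

scores = tf.constant([[0.0, 1.0, 5.0, 1.0]])  # [batch, width]
pos = tf.range(4, dtype=tf.float32)
com_strength = 10.0

o = tf.exp(com_strength * (scores - tf.reduce_max(scores, -1, keepdims=True)))
x = tf.reduce_sum(o * pos, -1) / tf.reduce_sum(o, -1)
with tf.Session() as sess:
    print(sess.run(x))  # close to 2.0, the argmax position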
Example #13
    def __init__(self, config, word_embeddings, law_input, law_doc_length,
                 law_sent_length):
        self.lstm_size = config.lstm_size
        self.lstm_law_size = config.lstm_law_size
        self.k_laws = config.k_laws
        self.doc_len = config.doc_len
        self.sent_len = config.sent_len
        self.law_sent_len = config.law_sent_len
        self.law_doc_len = config.law_doc_len
        self.batch_size = config.batch_size
        self.n_law = config.n_law
        self.n_class = config.n_class
        self.keep_prob = tf.placeholder(tf.float32, [])
        self.input = tf.placeholder(
            tf.int32, [self.batch_size, self.doc_len, self.sent_len])
        self.input_doc_length = tf.placeholder(tf.int32, [self.batch_size])
        self.input_sent_length = tf.placeholder(
            tf.int32, [self.batch_size, self.doc_len])
        self.label = tf.placeholder(tf.float32,
                                    [self.batch_size, self.n_class])
        self.law_label = tf.placeholder(tf.float32,
                                        [self.batch_size, self.n_law])
        self.embedding = tf.Variable(word_embeddings, trainable=False)
        inputs = tf.nn.embedding_lookup(self.embedding, self.input)
        laws = tf.nn.embedding_lookup(self.embedding, law_input)

        inputs_shape = utils.get_tensor_shape(inputs)
        self.n_fact_feat = self.n_law_feat = inputs_shape[-1]

        with tf.name_scope('get_laws'):
            inputs_2d = tf.reshape(inputs, [
                self.batch_size,
                self.doc_len * self.sent_len * self.n_fact_feat
            ])
            law_index = layers.fully_connected(inputs_2d,
                                               self.n_law,
                                               activation_fn=tf.nn.softmax)
            # law_index=tf.nn.top_k(law_index,self.k_laws)
            self.law_index = tf.contrib.framework.argsort(law_index,
                                                          -1)[:, -self.k_laws:]
            self.law_index = tf.reshape(self.law_index,
                                        [self.batch_size, self.k_laws])
            self.laws = tf.nn.embedding_lookup(laws, self.law_index)
            self.k_law_label = tf.map_fn(lambda x: tf.gather(x[0], x[1]),
                                         (self.law_label, self.law_index),
                                         dtype=tf.float32)
            law_doc_length = tf.gather(law_doc_length, self.law_index)
            law_sent_length = tf.gather(law_sent_length, self.law_index)

        with tf.name_scope('fact_encoder'):
            self.n_fact_feat = self.lstm_size * 2
            # inputs=tf.reshape(inputs, [self.batch_size * self.doc_len, self.sent_len, self.n_fact_feat])
            u_fw = tf.get_variable('u_fw',
                                   shape=[1, self.n_fact_feat],
                                   initializer=layers.xavier_initializer())
            u_fs = tf.get_variable('u_fs',
                                   shape=[1, self.n_fact_feat],
                                   initializer=layers.xavier_initializer())
            sent_encoded, _ = self.seq_encoder(inputs, u_fw, config.lstm_size,
                                               self.input_sent_length,
                                               'fact_sent')
            # sent_encoded=tf.reshape(sent_encoded, [self.batch_size, self.doc_len, self.n_fact_feat])
            self.d_f, _ = self.seq_encoder(sent_encoded, u_fs,
                                           config.lstm_size,
                                           self.input_doc_length, 'fact_doc')

        with tf.name_scope('law_encoder'):
            self.n_law_feat = self.lstm_law_size * 2
            u_aw = tf.reshape(tf.layers.dense(self.d_f, self.n_law_feat),
                              [self.batch_size, 1, 1, 1, self.n_law_feat])
            u_as = tf.reshape(tf.layers.dense(self.d_f, self.n_law_feat),
                              [self.batch_size, 1, 1, self.n_law_feat])
            # laws=tf.reshape(self.laws, [self.batch_size * self.k_laws * self.law_doc_len, self.law_sent_len, self.n_law_feat // 2])
            # law_sent_length=tf.reshape(law_sent_length,[self.batch_size*self.k_laws*self.law_doc_len])
            sent_encoded, _ = self.seq_encoder(self.laws, u_aw,
                                               config.lstm_law_size,
                                               law_sent_length, 'law_sent')
            # sent_encoded=tf.reshape(sent_encoded, [self.batch_size * self.k_laws, self.doc_len, self.n_law_feat])
            law_repr, _ = self.seq_encoder(sent_encoded, u_as,
                                           config.lstm_law_size,
                                           law_doc_length, 'law_doc')
            # law_repr=tf.reshape(law_repr, [self.batch_size, self.k_laws, self.n_law_feat])

        with tf.name_scope('law_aggregator'):
            u_ad = tf.reshape(tf.layers.dense(self.d_f, self.n_fact_feat),
                              [self.batch_size, 1, self.n_law_feat])
            self.d_a, self.law_score = self.seq_encoder(
                law_repr, u_ad, config.lstm_law_size,
                [self.k_laws] * self.batch_size, 'aggregator')

        with tf.name_scope('softmax'):
            self.outputs1 = layers.fully_connected(
                tf.concat([self.d_f, self.d_a], -1), config.fc_size1)
            self.outputs2 = layers.fully_connected(self.outputs1,
                                                   config.fc_size2)
            self.outputs = layers.fully_connected(self.outputs2,
                                                  self.n_class,
                                                  activation_fn=None)

        # with tf.name_scope('lstm'):
        #     lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.lstm_size, forget_bias=0.0)
        #     lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=self.keep_prob)
        #     cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * config.num_layers)
        #     self.initial_state = cell.zero_state(self.batch_size, tf.float32)
        #     with tf.variable_scope('context'):
        #         outputs, _ = tf.nn.dynamic_rnn(cell, inputs, initial_state=self.initial_state,
        #                                        sequence_length=self.input_length)
        #
        # output = tf.expand_dims(tf.reshape(outputs, [self.batch_size, -1, self.lstm_size]), -1)
        #
        # with tf.name_scope("lstm_maxpool"):
        #     output_pooling = tf.nn.max_pool(output,
        #                                     ksize=[1, self.doc_len, 1, 1],
        #                                     strides=[1, 1, 1, 1],
        #                                     padding='VALID',
        #                                     name="pool")

        # with tf.name_scope('lstm_law') as scope:
        #
        #     lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(self.lstm_law_size, forget_bias=0.0)
        #     lstm_cell = tf.nn.rnn_cell.DropoutWrapper(lstm_cell, output_keep_prob=self.keep_prob)
        #     cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * config.num_layers)
        #     self.initial_state = cell.zero_state(self.batch_size*self.k_laws, tf.float32)
        #     laws_extracted=tf.reshape(self.laws,[self.batch_size*self.k_laws,self.doc_len,self.embedding.get_shape().as_list()[-1]])
        #     with tf.variable_scope('law'):
        #         outputs_law, _ = tf.nn.dynamic_rnn(cell, laws_extracted, initial_state=self.initial_state,
        #                                        sequence_length=tf.reshape(law_length,[-1]))
        #     outputs_law=tf.reshape(outputs_law,[self.batch_size,self.k_laws,self.doc_len,self.lstm_law_size])

        self.prediction = tf.where(
            tf.nn.softmax(self.outputs) > config.threshold)
        self.law_prediction = tf.where(self.law_score > config.law_threshold,
                                       tf.to_float(self.law_index),
                                       tf.zeros_like(self.law_score))
        self.loss_main = tf.losses.log_loss(self.norm_sum(self.label),
                                            tf.nn.softmax(self.outputs))
        self.loss_law = tf.losses.log_loss(self.norm_sum(self.k_law_label),
                                           self.law_score)
        loss_reg = tf.losses.get_regularization_loss() * config.l2_ratio
        self.loss = self.loss_main + loss_reg + config.attention_loss_ratio * self.loss_law
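
The get_laws block selects the k_laws highest-scoring laws with an argsort; a minimal sketch of that indexing on toy scores:

import tensorflow as tf

k = 2
scores = tf.constant([[0.1, 0.7, 0.2],
                      [0.5, 0.1, 0.4]])
top_idx = tf.contrib.framework.argsort(scores, -1)[:, -k:]  # k largest
with tf.Session() as sess:
    print(sess.run(top_idx))  # [[2 1], [2 0]]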
Example #14
def norm_spatial_subtractive(inputs, sub_kernel, data_format="NHWC"):
    """Performs the spatial subtractive normalization

    Parameters
    ----------

    inputs: tensorflow 4D tensor, NHWC format

    input to the network

    sub_kernel: numpy.ndarray, 2D matrix

    the subtractive normalization kernel

    """

    raise NotImplementedError(
        "This function is buggy! don't use before extensive debugging!")

    # ----------
    # Normalize kernel.
    # Note that unlike Torch, we don't divide the kernel here. We divide
    # when it is fed to the convolution, since we use it to generate the
    # coefficient map.
    kernel = sub_kernel.astype("float32")
    norm_kernel = (kernel / np.sum(kernel))

    # ----------
    # Compute the adjustment coef.
    # This allows our mean computation to compensate for the border area,
    # where you have less terms adding up. Torch used convolution with a
    # ``one'' image, but since we do not want the library to depend on
    # other libraries with convolutions, we do it manually here.
    input_shape = get_tensor_shape(inputs)
    assert len(input_shape) == 4
    if data_format == "NHWC":
        coef = np.ones(input_shape[1:3], dtype="float32")
    else:
        coef = np.ones(input_shape[2:], dtype="float32")
    pad_x = norm_kernel.shape[1] // 2
    pad_y = norm_kernel.shape[0] // 2

    # Corners
    # for the top-left corner
    tl_cumsum_coef = np.cumsum(np.cumsum(norm_kernel[::-1, ::-1], axis=0),
                               axis=1)[::1, ::1]
    coef[:pad_y + 1, :pad_x + 1] = tl_cumsum_coef[pad_y:, pad_x:]
    # for the top-right corner
    tr_cumsum_coef = np.cumsum(np.cumsum(norm_kernel[::-1, ::1], axis=0),
                               axis=1)[::1, ::-1]
    coef[:pad_y + 1, -pad_x - 1:] = tr_cumsum_coef[pad_y:, :pad_x + 1]
    # for the bottom-left corner
    bl_cumsum_coef = np.cumsum(np.cumsum(norm_kernel[::1, ::-1], axis=0),
                               axis=1)[::-1, ::1]
    coef[-pad_y - 1:, :pad_x + 1] = bl_cumsum_coef[:pad_y + 1, pad_x:]
    # for the bottom-right corner
    br_cumsum_coef = np.cumsum(np.cumsum(norm_kernel[::1, ::1], axis=0),
                               axis=1)[::-1, ::-1]
    coef[-pad_y - 1:, -pad_x - 1:] = br_cumsum_coef[:pad_y + 1, :pad_x + 1]

    # Sides
    tb_slice = slice(pad_y + 1, -pad_y - 1)
    # for the left side
    fill_value = tl_cumsum_coef[-1, pad_x:]
    coef[tb_slice, :pad_x + 1] = fill_value.reshape([1, -1])
    # for the right side
    fill_value = br_cumsum_coef[0, :pad_x + 1]
    coef[tb_slice, -pad_x - 1:] = fill_value.reshape([1, -1])
    lr_slice = slice(pad_x + 1, -pad_x - 1)
    # for the top side
    fill_value = tl_cumsum_coef[pad_y:, -1]
    coef[:pad_y + 1, lr_slice] = fill_value.reshape([-1, 1])
    # for the bottom side
    fill_value = br_cumsum_coef[:pad_y + 1, 0]
    coef[-pad_y - 1:, lr_slice] = fill_value.reshape([-1, 1])

    # # code for validation of above
    # img = np.ones_like(input, dtype='float32')
    # import cv2
    # coef_cv2 = cv2.filter2D(img, -1, norm_kernel,
    #                         borderType=cv2.BORDER_CONSTANT)

    # ----------
    # Extract convolutional mean
    # Make filter a c01 filter by repeating. Note that we normalized above
    # with the number of repetitions we are going to do.
    if data_format == "NHWC":
        norm_kernel = np.tile(norm_kernel, [input_shape[-1], 1, 1])
    else:
        norm_kernel = np.tile(norm_kernel, [input_shape[1], 1, 1])
    # Re-normalize the kernel so that the sum is one.
    norm_kernel /= np.sum(norm_kernel)
    # Add another axis in front to make an oc01 filter, where o is the
    # number of output dimensions (in our case, 1!)
    norm_kernel = norm_kernel[np.newaxis, ...]
    # # To pad with zeros, half the size of the kernel (only for 01 dims)
    # border_mode = tuple(s // 2 for s in norm_kernel.shape[2:])
    # Convolve the mean filter. Results in shape of (batch_size,
    # input_shape[1], input_shape[2], 1).
    # For tensorflow, the kernel shape is 01co, which is different.... why?!
    conv_mean = tf.nn.conv2d(
        inputs,
        norm_kernel.astype("float32").transpose(2, 3, 1, 0),
        strides=[1, 1, 1, 1],
        padding="SAME",
        data_format=data_format,
    )

    # ----------
    # Adjust convolutional mean with precomputed coef
    # This is to prevent border values being too small.
    if data_format == "NHWC":
        coef = coef[None][..., None].astype("float32")
    else:
        coef = coef[None, None].astype("float32")
    adj_mean = conv_mean / coef
    # # Make second dimension broadcastable as we are going to
    # # subtract for all channels.
    # adj_mean = T.addbroadcast(adj_mean, 1)

    # ----------
    # Subtract mean
    sub_normalized = inputs - adj_mean

    # # line for debugging
    # test = theano.function(inputs=[input], outputs=[sub_normalized])

    return sub_normalized
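
The hand-built border coefficient map is what one would also get by convolving an all-ones image with the normalized kernel (the cv2.filter2D validation hinted at in the comments); a small sketch:

import numpy as np
import tensorflow as tf

norm_kernel = np.ones((3, 3), np.float32) / 9.0
ones_img = np.ones((1, 5, 5, 1), np.float32)

coef = tf.nn.conv2d(ones_img, norm_kernel[:, :, None, None],
                    strides=[1, 1, 1, 1], padding="SAME")
with tf.Session() as sess:
    # Interior stays 1.0; borders shrink where the kernel leaves the image.
    print(sess.run(coef)[0, :, :, 0])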