def test_nhwc(self, padding):
        strides = self.strides_nhwc
        ksize = self.ksize_nhwc
        output = self.output_nhwc
        np_nhwc = self.grad_nhwc[padding]
        if padding == "VALID":
            grad = tf.placeholder(tf.float32, shape=(128, 112, 74, 3))
        elif padding == "SAME":
            grad = tf.placeholder(tf.float32, shape=(128, 112, 75, 3))

        with self.device:
            a = max_pool_grad(
                self.input_nhwc,
                output,
                grad,
                ksize,
                strides,
                padding=padding,
                data_format="NHWC")
            with self.session as sess:
                result = sess.run(a, feed_dict={grad: np_nhwc})

        with tf.device('/cpu:0'):
            b = max_pool_grad(
                self.input_nhwc,
                output,
                grad,
                ksize,
                strides,
                padding=padding,
                data_format="NHWC")
            with self.session as sess:
                expected = sess.run(b, feed_dict={grad: np_nhwc})

        np.testing.assert_allclose(result, expected, rtol=5e-7)
 def testDirectNotUseOverlapping(self):
   for num_batches in [1, 3]:
     for row_window_size in [2, 5]:
       for col_window_size in [2, 4]:
         num_rows = row_window_size * 5
         num_cols = col_window_size * 7
         for num_channels in [1, 2]:
           input_shape = (num_batches, num_rows, num_cols, num_channels)
           with self.cached_session() as _:
             input_tensor = constant_op.constant(
                 self._GenerateUniqueRandomInputTensor(input_shape))
             window_size = [1, row_window_size, col_window_size, 1]
             stride_size = [1, row_window_size, col_window_size, 1]
             padding = "VALID"
             output_tensor = nn_ops.max_pool(input_tensor, window_size,
                                             stride_size, padding)
             output_data = self.evaluate(output_tensor)
             output_backprop = self._PRNG.randint(100, size=output_data.shape)
             input_backprop_tensor = gen_nn_ops.max_pool_grad(
                 input_tensor, output_tensor, output_backprop, window_size,
                 stride_size, padding)
             input_backprop = self.evaluate(input_backprop_tensor)
             row_seq = list(range(0, num_rows + 1, row_window_size))
             col_seq = list(range(0, num_cols + 1, col_window_size))
             fmp_input_backprop_tensor = gen_nn_ops.fractional_max_pool_grad(
                 input_tensor,
                 output_tensor,
                 output_backprop,
                 row_seq,
                 col_seq,
                 overlapping=False)
             fmp_input_backprop = self.evaluate(fmp_input_backprop_tensor)
             self.assertShapeEqual(input_backprop, fmp_input_backprop_tensor)
             self.assertAllClose(input_backprop, fmp_input_backprop)
def tf_model(padding):
    orig_in = tf.placeholder(tf.float32, shape=[N, C, H, W])
    if padding == "VALID":
        grad = tf.placeholder(tf.float32, shape=valid_shape)
        orig_out = tf.placeholder(tf.float32, shape=valid_shape)
    elif padding == "SAME":
        grad = tf.placeholder(tf.float32, shape=same_shape)
        orig_out = tf.placeholder(tf.float32, shape=same_shape)

    # cast the input dtype to bfloat16 for TF
    orig_in_c = tf.cast(orig_in, tf.bfloat16)
    orig_out_c = tf.cast(orig_out, tf.bfloat16)
    grad_c = tf.cast(grad, tf.bfloat16)

    # transpose to NHWC
    orig_in_t = tf.transpose(orig_in_c, (0, 2, 3, 1))
    orig_out_t = tf.transpose(orig_out_c, (0, 2, 3, 1))
    grad_t = tf.transpose(grad_c, (0, 2, 3, 1))

    out = max_pool_grad(orig_in_t,
                        orig_out_t,
                        grad_t,
                        ksize_nhwc,
                        stride_nhwc,
                        padding=padding,
                        data_format="NHWC")

    # cast the output dtype back to float32
    output = tf.cast(out, tf.float32)

    # transpose to NCHW
    output_nchw = tf.transpose(output, (0, 3, 1, 2))
    return output_nchw, orig_in, orig_out, grad
Example #4
def tf_model(padding):
    orig_in = tf.placeholder(tf.float32, shape=[N, H, W, C])
    if padding == "VALID":
        grad = tf.placeholder(tf.float32, shape=valid_shape)
        orig_out = tf.placeholder(tf.float32, shape=valid_shape)
    elif padding == "SAME":
        grad = tf.placeholder(tf.float32, shape=same_shape)
        orig_out = tf.placeholder(tf.float32, shape=same_shape)

    # cast the input dtype to bfloat16 for TF
    orig_in_c = tf.cast(orig_in, tf.bfloat16)
    orig_out_c = tf.cast(orig_out, tf.bfloat16)
    grad_c = tf.cast(grad, tf.bfloat16)

    out = max_pool_grad(
        orig_in_c,
        orig_out_c,
        grad_c,
        ksize_nhwc,
        stride_nhwc,
        padding=padding,
        data_format="NHWC")

    # cast the output dtype back to float32
    output = tf.cast(out, tf.float32)
    return output, orig_in, orig_out, grad
Example #5
 def bp(self, AI, AO, DO, cache):
     DI = gen_nn_ops.max_pool_grad(grad=DO,
                                   orig_input=AI,
                                   orig_output=AO,
                                   ksize=self.ksize,
                                   strides=self.strides,
                                   padding=self.padding)
     return {'dout': DI, 'cache': {}}, []
Example #6
def _MaxPoolGrad(op, grad):
    return gen_nn_ops.max_pool_grad(op.inputs[0],
                                    op.outputs[0],
                                    grad,
                                    op.get_attr("ksize"),
                                    op.get_attr("strides"),
                                    padding=op.get_attr("padding"),
                                    data_format=op.get_attr("data_format"))
Example #7
 def backward(self, AI, AO, DO):
     grad = gen_nn_ops.max_pool_grad(grad=DO,
                                     orig_input=AI,
                                     orig_output=AO,
                                     ksize=self.ksize,
                                     strides=self.strides,
                                     padding=self.padding)
     return grad
 def test_on_ng(sess):
     a = max_pool_grad(self.input_nchw,
                       output,
                       grad,
                       ksize,
                       strides,
                       padding=padding,
                       data_format="NCHW")
     return sess.run(a, feed_dict={grad: g_nchw})
Example #9
 def dfa_backward(self, AI, AO, E, DO):
     grad = gen_nn_ops.max_pool_grad(grad=DO,
                                     orig_input=AI,
                                     orig_output=AO,
                                     ksize=self.ksize,
                                     strides=self.strides,
                                     padding=self.padding)
     # grad = tf.Print(grad, [tf.shape(grad), tf.count_nonzero(tf.equal(grad, 1)), tf.count_nonzero(tf.equal(grad, 2)), tf.count_nonzero(tf.equal(grad, 3)), tf.count_nonzero(tf.equal(grad, 4)), tf.count_nonzero(tf.equal(grad, 5))], message="", summarize=1000)
     return grad
Example #10
def _MaxPoolGrad(op, grad):
  return gen_nn_ops.max_pool_grad(
      op.inputs[0],
      op.outputs[0],
      grad,
      op.get_attr("ksize"),
      op.get_attr("strides"),
      padding=op.get_attr("padding"),
      data_format=op.get_attr("data_format"))
Example #11
    def test_nchw(self, padding):
        strides = self.strides_nchw
        ksize = self.ksize_nchw
        output = self.output_nchw
        np_nchw = self.grad_nchw[padding]
        if padding == "VALID":
            grad = tf.placeholder(tf.float32, shape=(128, 3, 112, 74))
        elif padding == "SAME":
            grad = tf.placeholder(tf.float32, shape=(128, 3, 112, 75))

        with self.device:
            a = max_pool_grad(
                self.input_nchw,
                output,
                grad,
                ksize,
                strides,
                padding=padding,
                data_format="NCHW")
            with self.session as sess:
                result = sess.run(a, feed_dict={grad: np_nchw})
        # To validate on the CPU side we need to run in NHWC, because the CPU
        # implementation of max pool backprop does not support NCHW. We
        # transpose on the way in and on the way out.
        with tf.device('/cpu:0'):
            grad = tf.transpose(grad, NCHW_TO_NHWC)
            np_nhwc = self.grad_nhwc[padding]
            output = self.output_nhwc
            ksize = self.ksize_nhwc
            strides = self.strides_nhwc
            b = max_pool_grad(
                self.input_nhwc,
                output,
                grad,
                ksize,
                strides,
                padding=padding,
                data_format="NHWC")
            b = tf.transpose(b, NHWC_TO_NCHW)
            with self.session as sess:
                expected = sess.run(b, feed_dict={grad: np_nhwc})

        np.testing.assert_allclose(result, expected, rtol=5e-7)
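The NCHW_TO_NHWC and NHWC_TO_NCHW permutations used in this test are defined elsewhere in the module; a plausible definition, consistent with the explicit (0, 2, 3, 1) and (0, 3, 1, 2) transposes in the tf_model example above, is:

# Assumed layout-permutation constants (not part of the snippet above).
NCHW_TO_NHWC = (0, 2, 3, 1)  # move channels from axis 1 to the last axis
NHWC_TO_NCHW = (0, 3, 1, 2)  # move channels from the last axis back to axis 1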
Example #12
def _MaxPoolGradGradGrad(op, grad):
    return (array_ops.zeros_like(op.inputs[0]),
            array_ops.zeros_like(op.inputs[1]),
            gen_nn_ops.max_pool_grad(op.inputs[0],
                                     op.inputs[1],
                                     grad,
                                     op.get_attr("ksize"),
                                     op.get_attr("strides"),
                                     padding=op.get_attr("padding"),
                                     data_format=op.get_attr("data_format")))
Example #13
def _MaxPoolGradGradGrad(op, grad):
  return (array_ops.zeros(
      shape=array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype),
          array_ops.zeros(
              shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
          gen_nn_ops.max_pool_grad(
              op.inputs[0],
              op.inputs[1],
              grad,
              op.get_attr("ksize"),
              op.get_attr("strides"),
              padding=op.get_attr("padding"),
              data_format=op.get_attr("data_format")))
Example #14
    def testFwdAndBwdMaxPool(self):
        input = np.arange(16).reshape(1, 4, 4, 1)
        output_grad = np.full((1, 2, 2, 1), 0.1)

        with ops.device("/device:IPU:0"):
            pa = array_ops.placeholder(np.float32, [1, 4, 4, 1], name="a")
            pb = array_ops.placeholder(np.float32, [1, 2, 2, 1], name="b")
            c = nn.max_pool(pa,
                            ksize=[1, 2, 2, 1],
                            strides=[1, 2, 2, 1],
                            data_format='NCHW',
                            padding='SAME')
            d = gen_nn_ops.max_pool_grad(pa,
                                         c,
                                         pb,
                                         ksize=[1, 2, 2, 1],
                                         strides=[1, 2, 2, 1],
                                         data_format='NCHW',
                                         padding='SAME')

        with ops.device('cpu'):
            report = gen_ipu_ops.ipu_event_trace()

        tu.configure_ipu_system()

        with tu.ipu_session() as sess:
            sess.run(report)
            fe = {
                pa: input,
                pb: output_grad,
            }
            output, input_grad = sess.run((c, d), fe)
            self.assertAllClose(output, [[[[5.], [7.]], [[13.], [15.]]]])
            self.assertAllClose(
                input_grad,
                [[[[0.], [0.], [0.], [0.]], [[0.], [0.1], [0.], [0.1]],
                  [[0.], [0.], [0.], [0.]], [[0.], [0.1], [0.], [0.1]]]])

            result = sess.run(report)
            self.assertTrue(len(result) == 3)

            s = tu.extract_all_strings_from_event_trace(result)
            cs_list = tu.get_compute_sets_from_report(s)

            ok = [
                '__seed*', 'Copy_*', 'MaxPool/custom-call*/maxPool2x2/',
                'MaxPoolGrad/custom-call*/maxPool2x2'
            ]
            self.assertTrue(tu.check_all_compute_sets_and_list(cs_list, ok))
Example #15
 def testDirectUseOverlapping(self):
     for num_batches in [1, 3]:
         for row_window_size in [2, 5]:
             for col_window_size in [2, 4]:
                 num_rows = (row_window_size - 1) * 5 + 1
                 num_cols = (col_window_size - 1) * 7 + 1
                 for num_channels in [1, 2]:
                     input_shape = (num_batches, num_rows, num_cols,
                                    num_channels)
                     with self.test_session() as _:
                         input_tensor = constant_op.constant(
                             self._GenerateUniqueRandomInputTensor(
                                 input_shape))
                         window_size = [
                             1, row_window_size, col_window_size, 1
                         ]
                         stride_size = [
                             1, row_window_size - 1, col_window_size - 1, 1
                         ]
                         padding = "VALID"
                         output_tensor = nn_ops.max_pool(
                             input_tensor, window_size, stride_size,
                             padding)
                         output_data = output_tensor.eval()
                         output_backprop = self._PRNG.randint(
                             100, size=output_data.shape)
                         input_backprop_tensor = gen_nn_ops.max_pool_grad(
                             input_tensor, output_tensor, output_backprop,
                             window_size, stride_size, padding)
                         input_backprop = input_backprop_tensor.eval()
                         row_seq = list(
                             range(0, num_rows, row_window_size - 1))
                         col_seq = list(
                             range(0, num_cols, col_window_size - 1))
                         row_seq[-1] += 1
                         col_seq[-1] += 1
                         fmp_input_backprop_tensor = gen_nn_ops.fractional_max_pool_grad(
                             input_tensor,
                             output_tensor,
                             output_backprop,
                             row_seq,
                             col_seq,
                             overlapping=True)
                         fmp_input_backprop = fmp_input_backprop_tensor.eval()
                         self.assertShapeEqual(input_backprop,
                                               fmp_input_backprop_tensor)
                         self.assertAllClose(input_backprop,
                                             fmp_input_backprop)
Example #16
def fprop_pool(F, X, strides=None, ksize=None, padding='SAME'):
    # Propagate relevance F through the pooling layer
    xshape = X.get_shape().as_list()
    fshape = F.get_shape().as_list()
    if len(xshape) != len(fshape):
        F = tf.reshape(F, (-1, int(np.ceil(xshape[1]/2.0)), 
                               int(np.ceil(xshape[2]/2.0)), xshape[3]))
    ksize = [1, 2, 2, 1]  if ksize is None else ksize
    strides = [1, 2, 2, 1]  if strides is None else strides

    Z = tf.nn.max_pool(X, strides=strides, ksize=ksize, padding=padding) + 1e-9
    S = F / Z
    C = gen_nn_ops.max_pool_grad(X, Z, S, ksize, strides, padding)    
    F = X*C
    return F
 def test_on_tf(sess):
     grad_t = tf.transpose(grad, NCHW_TO_NHWC)
     ksize = self.ksize_nhwc
     strides = self.strides_nhwc
     input_t = np.transpose(self.input_nchw, NCHW_TO_NHWC)
     output_t = np.transpose(output, NCHW_TO_NHWC)
     b = max_pool_grad(input_t,
                       output_t,
                       grad_t,
                       ksize,
                       strides,
                       padding=padding,
                       data_format="NHWC")
     b = tf.transpose(b, NHWC_TO_NCHW)
     return sess.run(b, feed_dict={grad: g_nchw})
def ng_model(padding):
    orig_in = tf.placeholder(tf.float32, shape=[N, C, H, W])
    if padding == "VALID":
        grad = tf.placeholder(tf.float32, shape=valid_shape)
        orig_out = tf.placeholder(tf.float32, shape=valid_shape)
    elif padding == "SAME":
        grad = tf.placeholder(tf.float32, shape=same_shape)
        orig_out = tf.placeholder(tf.float32, shape=same_shape)

    out = max_pool_grad(orig_in,
                        orig_out,
                        grad,
                        ksize_nchw,
                        stride_nchw,
                        padding=padding,
                        data_format="NCHW")
    return out, orig_in, orig_out, grad
def unpool1d(origin_name, pool_value, k=2, stride=2):
    """
    :param origin_name: A string point to 3D tensor with shape [M, T, D]
                      contain argmax indices
    :param pool_value: A 3D tensor with shape [M, T//stride, D]
    :return: unpooling_value: A 3D tensor with shape [M, T, D]
    """
    origin_name += ":0"
    mask_value = tf.get_default_graph().get_tensor_by_name(origin_name)
    mask_value = tf.expand_dims(mask_value, axis=2)
    pool_value = tf.expand_dims(pool_value, axis=2)
    k_sizes = [1, k, 1, 1]
    strides = [1, stride, 1, 1]
    unpool = gen_nn_ops.max_pool_grad(mask_value, pool_value, pool_value,
                                      k_sizes, strides, 'VALID')
    unpool = tf.squeeze(unpool, axis=2)
    return unpool
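A minimal usage sketch for unpool1d, under the assumption that the pre-pooling tensor was given an explicit name in the graph; the placeholder name and shapes below are illustrative, not from the original project:

import tensorflow as tf

# Hypothetical shapes: batch M=8, time T=100, depth D=64.
x = tf.placeholder(tf.float32, shape=[8, 100, 64], name="pre_pool")   # [M, T, D]
x4d = tf.expand_dims(x, axis=2)                                       # [M, T, 1, D]
pooled = tf.nn.max_pool(x4d, ksize=[1, 2, 1, 1], strides=[1, 2, 1, 1],
                        padding='VALID')
pooled = tf.squeeze(pooled, axis=2)                                   # [M, T//2, D]
# unpool1d looks up "pre_pool:0" by name and routes each pooled value back to
# the time step it came from; all other time steps receive zero.
restored = unpool1d("pre_pool", pooled, k=2, stride=2)                # [M, T, D]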
def backprop_pool(activation,
                  relevance,
                  ksize,
                  strides,
                  pooling_type,
                  padding='VALID'):
    if pooling_type.lower() == 'avg':  # avg pooling
        z = nn_ops.avg_pool(activation, ksize, strides, padding) + 1e-10
        s = relevance / z
        c = gen_nn_ops._avg_pool_grad(tf.shape(activation), s, ksize, strides,
                                      padding)
        return activation * c
    else:  # max pooling
        z = nn_ops.max_pool(activation, ksize, strides, padding) + 1e-10
        s = relevance / z
        c = gen_nn_ops.max_pool_grad(activation, z, s, ksize, strides, padding)
        return activation * c
Example #21
    def _non_maxima_suppression(self, image):

        use_smoothing = False
        if use_smoothing:
            smooth_k = InterestFilter.smooth_kernel(self._hparams.nms_size,
                                                    self._hparams.nms_std)
            smooth_k = smooth_k[:, :, tf.newaxis, tf.newaxis]
            image = tf.nn.conv2d(image,
                                 smooth_k,
                                 strides=[1, 1, 1, 1],
                                 padding='SAME')
            return image

        pool_size = self._hparams.nms_size
        stride = self._hparams.nms_stride
        padding = 'SAME'

        strides = [1, stride, stride, 1]
        ksize = [1, pool_size, pool_size, 1]

        # image, indices = tf.nn.max_pool_with_argmax(image, ksize=ksize, strides=strides, padding=padding)
        #
        # self.scalar_vals.append(('indices_max', tf.reduce_max(indices)))
        # self.scalar_vals.append(('image_max', tf.reduce_max(image)))
        #
        # print('non maxima pooled image: ', image)
        # print('non maxima indices: ', indices)
        #
        # image = layer_utils.unpool_2d(image, indices, strides)
        # print('non maxima unpooled image: ', image)

        # The unpooling output is also the gradient of the pooling operation,
        # so pool first and then recover the unpooled image from the gradient:
        # https://assiaben.github.io/posts/2018-06-tf-unpooling/
        img_op = image
        pool_op = tf.nn.max_pool(img_op,
                                 ksize=ksize,
                                 strides=strides,
                                 padding=padding,
                                 name='pool')
        unpool_op = gen_nn_ops.max_pool_grad(img_op, pool_op, pool_op, ksize,
                                             strides, padding)
        image = unpool_op

        return image
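The pool/unpool-via-gradient trick from the comment above can also be used on its own; a minimal sketch (shapes are illustrative, not taken from the original project):

import tensorflow as tf
from tensorflow.python.ops import gen_nn_ops

ksize = [1, 3, 3, 1]
strides = [1, 2, 2, 1]
img = tf.placeholder(tf.float32, shape=[1, 8, 8, 1])
pooled = tf.nn.max_pool(img, ksize=ksize, strides=strides, padding='SAME')
# Passing the pooled tensor as both orig_output and grad scatters each pooled
# value back to the position of its window maximum; every other position
# receives zero, i.e. the non-maxima-suppressed image.
unpooled = gen_nn_ops.max_pool_grad(img, pooled, pooled, ksize, strides, 'SAME')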
 def test_nhwc(self, padding):
     strides = self.strides_nhwc
     ksize = self.ksize_nhwc
     output = self.output_nhwc[padding]
     g_nhwc = self.grad_nhwc[padding]
     if padding == "VALID":
         grad = tf.placeholder(tf.float32, shape=(128, 112, 74, 3))
     elif padding == "SAME":
         grad = tf.placeholder(tf.float32, shape=(128, 112, 75, 3))
     out = max_pool_grad(self.input_nhwc,
                         output,
                         grad,
                         ksize,
                         strides,
                         padding=padding,
                         data_format="NHWC")
     sess_fn = lambda sess: sess.run(out, feed_dict={grad: g_nhwc})
     assert (np.allclose(self.with_ngraph(sess_fn),
                         self.without_ngraph(sess_fn),
                         rtol=5e-7))
 def _simple_lrp(self, R):
     '''
     LRP according to Eq(56) in DOI: 10.1371/journal.pone.0130140
     '''
     self.check_shape(R)
     if self.R.shape[1] == 1:
         self.R = tf.reshape(self.R, [self.batch_size, 7, 7, 512])
     Z = tf.nn.max_pool(self.input_tensor,
                        ksize=self.pool_kernel,
                        strides=self.pool_stride,
                        padding='SAME') + 1e-9
     S = self.R / Z
     C = gen_nn_ops.max_pool_grad(self.input_tensor,
                                  Z,
                                  S,
                                  ksize=self.pool_kernel,
                                  strides=self.pool_stride,
                                  padding='SAME')
     result = self.input_tensor * C
     return result
Example #24
  def testFwdAndBwdMaxPool(self):
    with self.session() as sess:
      input_values = np.arange(16).reshape(1, 4, 4, 1)
      output_grad = np.full((1, 2, 2, 1), 0.1)

      with ops.device("/device:IPU:0"):
        pa = array_ops.placeholder(np.float32, [1, 4, 4, 1], name="a")
        pb = array_ops.placeholder(np.float32, [1, 2, 2, 1], name="b")
        c = nn.max_pool(pa,
                        ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1],
                        data_format='NCHW',
                        padding='SAME')
        d = gen_nn_ops.max_pool_grad(pa,
                                     c,
                                     pb,
                                     ksize=[1, 2, 2, 1],
                                     strides=[1, 2, 2, 1],
                                     data_format='NCHW',
                                     padding='SAME')

      report = tu.ReportJSON(self, sess)
      report.reset()

      fe = {
          pa: input_values,
          pb: output_grad,
      }
      output, input_grad = sess.run((c, d), fe)
      self.assertAllClose(output, [[[[5.], [7.]], [[13.], [15.]]]])
      self.assertAllClose(
          input_grad, [[[[0.], [0.], [0.], [0.]], [[0.], [0.1], [0.], [0.1]],
                        [[0.], [0.], [0.], [0.]], [[0.], [0.1], [0.], [0.1]]]])

      report.parse_log(assert_len=4)

      ok = [
          '__seed*', 'Copy_*', 'MaxPool/max-pool*/maxPool2x2/',
          'MaxPoolGrad/max-pool-grad*/maxPool2x2'
      ]
      report.assert_all_compute_sets_and_list(ok)
 def testDirectUseOverlapping(self):
   for num_batches in [1, 3]:
     for row_window_size in [2, 5]:
       for col_window_size in [2, 4]:
         num_rows = (row_window_size - 1) * 5 + 1
         num_cols = (col_window_size - 1) * 7 + 1
         for num_channels in [1, 2]:
           input_shape = (num_batches, num_rows, num_cols, num_channels)
           with self.test_session() as _:
             input_tensor = constant_op.constant(
                 self._GenerateUniqueRandomInputTensor(input_shape))
             window_size = [1, row_window_size, col_window_size, 1]
             stride_size = [1, row_window_size - 1, col_window_size - 1, 1]
             padding = "VALID"
             output_tensor = nn_ops.max_pool(input_tensor, window_size,
                                             stride_size, padding)
             output_data = output_tensor.eval()
             output_backprop = self._PRNG.randint(100, size=output_data.shape)
             input_backprop_tensor = gen_nn_ops.max_pool_grad(
                 input_tensor, output_tensor, output_backprop, window_size,
                 stride_size, padding)
             input_backprop = input_backprop_tensor.eval()
             row_seq = list(range(0, num_rows, row_window_size - 1))
             col_seq = list(range(0, num_cols, col_window_size - 1))
             row_seq[-1] += 1
             col_seq[-1] += 1
             fmp_input_backprop_tensor = gen_nn_ops.fractional_max_pool_grad(
                 input_tensor,
                 output_tensor,
                 output_backprop,
                 row_seq,
                 col_seq,
                 overlapping=True)
             fmp_input_backprop = fmp_input_backprop_tensor.eval()
             self.assertShapeEqual(input_backprop, fmp_input_backprop_tensor)
             self.assertAllClose(input_backprop, fmp_input_backprop)
Example #26
    def makeOps(self, graph):
        scratch, scratch2, rsz = graph.scratch, graph.scratch2, graph.rsz

        ##############################################################################
        ########################## Build forward ops
        ##############################################################################
        fsave = []
        out = self.prev.ftop

        ######### Handle res_in
        if self.nldef.res_in:
            assert self.prev.back
            rtop = tf.reshape(scratch2[:np.prod(rsz)], rsz)
            rchange = self.inSz == rsz
            if rchange:
                out = out + rtop
            else:
                rstride = rsz[1] // self.inSz[1]
                reskern = tf.ones([rstride, rstride, rsz[-1], 1],
                                  dtype=DT) / np.float32(rstride**2)
                respad = (self.inCh - rsz[-1]) // 2
                rt = rtop
                if rstride > 1:
                    rt = tf.nn.depthwise_conv2d(rt, reskern,
                                                [1, rstride, rstride, 1],
                                                'VALID')
                if respad > 0:
                    rt = tf.pad(rt, [[0, 0], [0, 0], [0, 0], [respad, respad]])
                out = out + rt
        out0 = out

        ### BN
        if self.nldef.bn:
            assert self.prev.back
            mu, vr = tf.nn.moments(out, [0, 1, 2])
            bnfac = tf.sqrt(vr + BNEPS)
            out = (out - mu) / bnfac
            self.bnfac = tf.Variable(tf.zeros_like(bnfac))
            fsave += [tf.assign(self.bnfac, bnfac).op]

            scale = tf.Variable(tf.ones([self.inCh], dtype=DT))
            sgrad = tf.Variable(tf.zeros([self.inCh], dtype=DT))
            graph.weights.append(scale)
            graph.grads.append(sgrad)
            cscale = tf.maximum(1e-8, scale)
            out = out * cscale

        ### Bias
        if self.prev.back:
            bias = tf.Variable(tf.zeros([self.inCh], dtype=DT))
            bgrad = tf.Variable(tf.zeros([self.inCh], dtype=DT))
            graph.weights.append(bias)
            graph.grads.append(bgrad)
            out = out + bias

        ### Save + Handle ReLUs
        self.btop = None

        # If has max-pool
        if self.nldef.maxpool:
            assert self.nldef.res_out != 1
            var = tf.Variable(tf.zeros(out.get_shape(), dtype=DT))
            fsave += [tf.assign(var, out).op]
            self.btop0 = (var - bias) / cscale
            self.btop1 = self.btop0

            self.Rm = tf.cast(var > 0, dtype=DT)
            btop = tf.nn.relu(var)
            self.premp = btop
            self.btop = tf.nn.max_pool(btop, [1, 3, 3, 1], [1, 2, 2, 1],
                                       'SAME')

            out = tf.nn.relu(out)
            out = tf.nn.max_pool(out, [1, 3, 3, 1], [1, 2, 2, 1], 'SAME')

        # Last layer
        elif self == graph.layers[-1]:
            fsave += [
                tf.assign(scratch2[:np.prod(self.inSz)], tf.reshape(out,
                                                                    [-1])).op
            ]
            var = tf.reshape(scratch2[:np.prod(self.inSz)], out.get_shape())

        # Quantization
        elif self.nldef.bn and (graph.qtype == 4 or graph.qtype == 8):
            assert self.nldef.relu

            sOp, outs, self.Rm = q.quant(graph.qtype, out / cscale,
                                         bias / cscale)
            fsave += sOp
            self.btop0 = outs - bias / cscale
            self.btop1 = self.btop0
            self.btop = tf.nn.relu(outs * cscale)
            out = tf.nn.relu(out)

        # No Quantization
        else:
            var = tf.Variable(tf.zeros(out.get_shape(), dtype=DT))
            fsave += [tf.assign(var, out).op]

        if self.btop is None:
            if self.nldef.bn:
                self.btop0 = (var - bias) / cscale
                self.btop1 = self.btop0

            if self.nldef.relu:
                self.btop = tf.nn.relu(var)
                self.Rm = tf.cast(var > 0, dtype=DT)
                out = tf.nn.relu(out)
            else:
                self.btop = var

        ######### Handle res_out
        if self.nldef.res_out is not None:
            graph.rsz = out.get_shape().as_list()
            sidx = np.prod(graph.rsz)
            if self.nldef.res_out == 1:
                fsave += [
                    tf.assign(scratch2[:sidx], tf.reshape(out0, [-1])).op
                ]
            else:
                fsave += [tf.assign(scratch2[:sidx], tf.reshape(out, [-1])).op]

        ########### Do the actual convolution
        kshp = [self.ksz, self.ksz, self.inCh, self.outCh]
        if self == graph.layers[-1]:
            sq = np.sqrt(1.0 / np.float32(self.ksz * self.ksz * self.inCh))
        else:
            sq = np.sqrt(2.0 / np.float32(self.ksz * self.ksz * self.inCh))

        kernel = tf.random_normal(kshp, stddev=sq, dtype=DT)
        kernel = tf.Variable(kernel)
        kgrad = tf.Variable(tf.zeros(kshp, dtype=DT))
        graph.weights.append(kernel)
        graph.grads.append(kgrad)

        if self.nldef.avpool:
            out = tf.reduce_mean(out, [1, 2], True)
        out = tf.nn.conv2d(out, kernel, [1, self.stride, self.stride, 1],
                           self.pad)

        ########### Store output in scratch
        fsave += [
            tf.assign(scratch[:np.prod(self.oSz)], tf.reshape(out, [-1])).op
        ]
        self.fOp = tf.group(*fsave)
        self.ftop = tf.reshape(scratch[:np.prod(self.oSz)], self.oSz)

        ##############################################################################
        ########################## Build Backward ops
        ##############################################################################
        ingrad = self.ftop  # Same shape loading from scratch
        bsave = []

        inp = self.btop
        if self.nldef.avpool:
            inp = tf.reduce_mean(inp, [1, 2], True)

        kg = tf.nn.conv2d_backprop_filter(inp, kshp, ingrad,
                                          [1, self.stride, self.stride, 1],
                                          self.pad)
        kg += graph.WD * kernel
        bsave += [tf.assign(kgrad, kg).op]

        if not self.prev.back:
            self.bOp = tf.group(*bsave)
            return

        if self.nldef.avpool:
            ingrad = tf.nn.conv2d_backprop_input(
                [self.inSz[0], 1, 1, self.inSz[3]], kernel, ingrad,
                [1, 1, 1, 1], 'VALID') / np.float32(
                    self.inSz[1] * self.inSz[2])
        elif self.nldef.maxpool:
            ingrad = tf.nn.conv2d_backprop_input([
                self.inSz[0], self.inSz[1] // 2, self.inSz[2] // 2,
                self.inSz[3]
            ], kernel, ingrad, [1, self.stride, self.stride, 1], self.pad)
        else:
            ingrad = tf.nn.conv2d_backprop_input(
                self.inSz, kernel, ingrad, [1, self.stride, self.stride, 1],
                self.pad)
        if self.nldef.res_out == 2:
            gshp = ingrad.get_shape().as_list()
            ingrad += tf.reshape(graph.scratch2[:np.prod(gshp)], gshp)

        if self.nldef.maxpool:
            ingrad = max_pool_grad(self.premp, self.btop, ingrad, [1, 3, 3, 1],
                                   [1, 2, 2, 1], 'SAME')

        if self.nldef.relu:
            ingrad *= self.Rm
        bsave += [tf.assign(bgrad, tf.reduce_sum(ingrad, [0, 1, 2])).op]
        if self.nldef.bn:
            bsave += [
                tf.assign(sgrad, tf.reduce_sum(ingrad * self.btop0,
                                               [0, 1, 2])).op
            ]
            ingrad = ingrad * cscale
            ingrad = ingrad - tf.reduce_mean(ingrad, [0, 1, 2])
            ingrad -= self.btop0 * tf.reduce_mean(ingrad * self.btop1,
                                                  [0, 1, 2])
            ingrad /= self.bnfac
        if self.nldef.res_out == 1:
            ingrad += tf.reshape(graph.scratch2[:np.prod(self.inSz)],
                                 self.inSz)

        bsave += [
            tf.assign(graph.scratch[:np.prod(self.inSz)],
                      tf.reshape(ingrad, [-1])).op
        ]

        if self.nldef.res_in:
            if rchange:
                bsave += [
                    tf.assign(graph.scratch2[:np.prod(self.inSz)],
                              tf.reshape(ingrad, [-1])).op
                ]
            else:
                if respad > 0:
                    ingrad = ingrad[:, :, :, respad:-respad]
                if rstride > 1:
                    ingrad = tf.nn.depthwise_conv2d_native_backprop_input(
                        rsz, reskern, ingrad, [1, rstride, rstride, 1],
                        'VALID')
                bsave += [
                    tf.assign(graph.scratch2[:np.prod(rsz)],
                              tf.reshape(ingrad, [-1])).op
                ]

        self.bOp = tf.group(*bsave)