Example #1
0
def _sparse_res_block_with_mask(x, ksize_list, block_params, strides, ind_init,
                                bin_init):
    """Sparse conv 2d with mask.

    Wraps precomputed mask indices in a ``ReduceMask`` and builds a sparse
    residual bottleneck block on ``x`` from them.

    Args:
        x: Input passed straight through to ``sparse_res_block_bottleneck``.
        ksize_list: Kernel sizes for the bottleneck convolutions.
        block_params: Block parameters for the sparse block.
        strides: Strides forwarded to the block.
        ind_init: Used as ``ReduceMask.active_block_indices``.
        bin_init: Used as ``ReduceMask.bin_counts``.

    Returns:
        Whatever ``sparse_res_block_bottleneck`` returns. It is called with
        the sixth positional argument set to True (presumably
        ``is_training`` -- verify against its signature), ``use_var=True``
        and ``data_format='NCHW'``.
    """
    mask_indices = ReduceMask(bin_counts=bin_init,
                              active_block_indices=ind_init)
    positional = (x, ksize_list, mask_indices, block_params, strides, True)
    return sparse_res_block_bottleneck(*positional,
                                       use_var=True,
                                       data_format='NCHW')
Example #2
0
    def _test_sparse_resblock_gradients(self,
                                        xval,
                                        maskval,
                                        bsize,
                                        strides,
                                        padding,
                                        data_format='NHWC',
                                        dynamic_size=False):
        """Print gradient-check diagnostics for a sparse residual bottleneck.

        Builds a fresh graph that applies ``sparse_res_block_bottleneck`` to
        ``xval``, using active-block indices derived from ``maskval``. It then
        prints, as a table, the gradient angle and absolute gradient error of
        the block output with respect to:

        - the input,
        - a fixed list of intermediate tensors, and
        - every trainable variable.

        The values come from ``compute_gradient_angle`` and
        ``compute_gradient_abs_error``. Nothing is asserted; the results are
        only printed.

        Args:
            xval: Input array. ``xval.shape[3]`` is used as the input channel
                count (assumes channels on axis 3 -- TODO confirm this is
                intended when ``data_format='NCHW'``).
            maskval: Mask array converted to active-block indices.
            bsize: Block size forwarded to ``calc_block_params_res_block``.
            strides: Strides forwarded to block-parameter calculation and to
                the block itself.
            padding: Padding forwarded to ``calc_block_params_res_block``.
            data_format: Forwarded to ``sparse_res_block_bottleneck``.
            dynamic_size: If True, block parameters are computed from the
                dynamic ``tf.shape(xval)`` instead of the static
                ``xval.shape``.
        """
        with tf.Graph().as_default() as g:
            x = tf.constant(xval)
            mask = tf.constant(maskval)
            ch_in = xval.shape[3]
            ch_out = xval.shape[3] // 4
            # Bottleneck kernels: 1x1 ch_in->ch_out, 3x3 ch_out->ch_out,
            # 1x1 ch_out->ch_in.
            ksize_list = [[1, 1, ch_in, ch_out], [3, 3, ch_out, ch_out],
                          [1, 1, ch_out, ch_in]]
            if dynamic_size:
                blk_params = calc_block_params_res_block(
                    tf.shape(xval), bsize, ksize_list, strides, padding)
            else:
                blk_params = calc_block_params_res_block(
                    xval.shape, bsize, ksize_list, strides, padding)
            ind = convert_mask_to_indices_custom(mask, blk_params, 0.)
            ReduceMask = namedtuple('ReduceMask',
                                    ['active_block_indices', 'bin_counts'])
            # NOTE(review): these static shapes are hard-coded and presumably
            # match only the specific inputs this helper is called with --
            # confirm before reusing it with other shapes.
            ind.active_block_indices.set_shape([27, 3])
            ind.bin_counts.set_shape([1])
            # Hold the indices in non-trainable variables. trainable=False
            # keeps them out of tf.trainable_variables(), so the per-variable
            # gradient check below does not visit them.
            ind_var = tf.Variable(ind.active_block_indices, trainable=False)
            bin_var = tf.Variable(ind.bin_counts, trainable=False)
            ind_fixed = ReduceMask(active_block_indices=ind_var,
                                   bin_counts=bin_var)

            y = sparse_res_block_bottleneck(x,
                                            ksize_list,
                                            ind_fixed,
                                            blk_params,
                                            strides,
                                            is_training=True,
                                            data_format=data_format,
                                            w_project=None,
                                            no_activation=False,
                                            use_var=False)
            trainable_vars = tf.trainable_variables()
            print('')
            print('-' * 55)
            print('Sparse Residual')
            print('{:30s} {:>10s} {:>10s}'.format('name', 'grad angle',
                                                  'abs err'))
            with self.test_session() as sess:
                sess.run(tf.global_variables_initializer())
                yval = y.eval()

                def _report(label, tensor, value):
                    """Print angle/abs-error of dy/d(tensor) evaluated at value."""
                    err = compute_gradient_angle(tensor,
                                                 value.shape,
                                                 y,
                                                 yval.shape,
                                                 x_init_value=value)
                    err2 = compute_gradient_abs_error(tensor,
                                                      value.shape,
                                                      y,
                                                      yval.shape,
                                                      x_init_value=value)
                    print('{:30s} {:>10.3f} {:>10.3f}'.format(
                        label, err, err2))

                _report('x', x, xval)

                # 'sub3/bn3/batchnorm/add_1:0' is intentionally not checked
                # (it was commented out of this list).
                for name in [
                        'SparseScatter:0', 'SparseGather:0',
                        'sub3/bn3/FusedBatchNorm:0', 'sub3/conv3/Conv2D:0',
                        'sub3/relu3:0', 'sub2/conv2/Conv2D:0', 'sub2/relu2:0',
                        'sub2/bn2/FusedBatchNorm:0', 'sub1/conv1/Conv2D:0',
                        'sub1/relu1:0', 'sub1/bn1/FusedBatchNorm:0'
                ]:
                    act = g.get_tensor_by_name(name)
                    _report(name, act, act.eval())

                for vv in trainable_vars:
                    _report(vv.name, vv, vv.eval())