def rpn_head(featuremap):
    with tf.variable_scope('rpn'), \
            argscope(Conv2D, data_format='NCHW',
                     W_init=tf.random_normal_initializer(stddev=0.01)):
        hidden = Conv2D('conv0', featuremap, 1024, 3, nl=tf.nn.relu)

        label_logits = Conv2D('class', hidden, config.NR_ANCHOR, 1)
        box_logits = Conv2D('box', hidden, 4 * config.NR_ANCHOR, 1)
        # 1, NA(*4), im/16, im/16 (NCHW)

        label_logits = tf.transpose(label_logits, [0, 2, 3, 1])  # 1xfHxfWxNA
        label_logits = tf.squeeze(label_logits, 0)  # fHxfWxNA

        shp = tf.shape(box_logits)  # 1x(NAx4)xfHxfW
        box_logits = tf.transpose(box_logits, [0, 2, 3, 1])  # 1xfHxfWx(NAx4)
        box_logits = tf.reshape(box_logits, tf.stack([shp[2], shp[3], config.NR_ANCHOR, 4]))  # fHxfWxNAx4
    return label_logits, box_logits
def maskrcnn_upXconv_head(feature, num_category, num_convs, norm=None):
    """
    Args:
        feature (N x C x s x s): size is 7 in C4 models and 14 in FPN models.
        num_category (int):
        num_convs (int): number of convolution layers
        norm (str or None): either None or 'GN'

    Returns:
        mask_logits (N x num_category x 2s x 2s):
    """
    assert norm in [None, 'GN'], norm
    l = feature
    with argscope([Conv2D, Conv2DTranspose],
                  data_format='channels_last',
                  kernel_initializer=tf.variance_scaling_initializer(
                      scale=2.0, mode='fan_out',
                      distribution='untruncated_normal' if get_tf_version_tuple() >= (1, 12) else 'normal')):
        # c2's MSRAFill is fan_out
        for k in range(num_convs):
            l = Conv2D('fcn{}'.format(k), l, cfg.MRCNN.HEAD_DIM, 3, activation=tf.nn.relu)
            if norm is not None:
                l = GroupNorm('gn{}'.format(k), l)
        l = Conv2DTranspose('deconv', l, cfg.MRCNN.HEAD_DIM, 2, strides=2, activation=tf.nn.relu)
        l = Conv2D('conv', l, num_category, 1,
                   kernel_initializer=tf.random_normal_initializer(stddev=0.001))
    return l
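# A small NumPy sketch (illustrative, not part of the head above) of the weight
# std-dev implied by variance_scaling(scale=2.0, mode='fan_out'), i.e. the
# Caffe2-style MSRAFill referenced in the comment. The kernel size and channel
# count below are assumed example values.
import numpy as np

k, c_out = 3, 256                       # 3x3 conv with 256 output channels (assumed)
fan_out = k * k * c_out
std = np.sqrt(2.0 / fan_out)            # ~0.0295: the Gaussian std used for the weights
print(round(float(std), 4))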
def encoder(i, freeze):
    """
    Pre-activated ResNet50 Encoder
    """
    d1 = Conv2D('conv0', i, 64, 7, padding='valid', strides=1, activation=BNReLU)
    d1 = res_blk('group0', d1, [64, 64, 256], [1, 3, 1], 3, strides=1, freeze=freeze)

    d2 = res_blk('group1', d1, [128, 128, 512], [1, 3, 1], 4, strides=2, freeze=freeze)
    d2 = tf.stop_gradient(d2) if freeze else d2

    d3 = res_blk('group2', d2, [256, 256, 1024], [1, 3, 1], 6, strides=2, freeze=freeze)
    d3 = tf.stop_gradient(d3) if freeze else d3

    d4 = res_blk('group3', d3, [512, 512, 2048], [1, 3, 1], 3, strides=2, freeze=freeze)
    d4 = tf.stop_gradient(d4) if freeze else d4

    d4 = Conv2D('conv_bot', d4, 1024, 1, padding='same')
    return [d1, d2, d3, d4]
def encoder_blk(name, feat_in, num_feats, has_down=False):
    with tf.variable_scope(name):
        feat = feat_in if not has_down else MaxPooling('pool1', feat_in, 2, strides=2, padding='same')
        feat = Conv2D('conv_1', feat, num_feats, 3, padding='valid', strides=1, activation=tf.nn.relu)
        feat = Conv2D('conv_2', feat, num_feats, 3, padding='valid', strides=1, activation=tf.nn.relu)
        return feat
def net(name, i, basis_filter_list, rot_matrix_list, nr_orients, filter_type, is_training):
    """
    Dense Steerable Filter CNN
    """
    dense_basis_list = [basis_filter_list[0], basis_filter_list[1]]
    dense_rot_list = [rot_matrix_list[0], rot_matrix_list[1]]

    with tf.variable_scope(name):
        c1 = GConv2D('ds_conv1', i, 8, 7, nr_orients, filter_type,
                     basis_filter_list[1], rot_matrix_list[1], input_layer=True)
        c2 = GConv2D('ds_conv2', c1, 8, 7, nr_orients, filter_type,
                     basis_filter_list[1], rot_matrix_list[1])
        p1 = MaxPooling('max_pool1', c2, 2)
        ####
        d1 = g_dense_blk('dense1', p1, [32, 8], [5, 7], 2, nr_orients, filter_type,
                         dense_basis_list, dense_rot_list, bn_init=False)
        c3 = GConv2D('ds_conv3', d1, 32, 5, nr_orients, filter_type,
                     basis_filter_list[0], rot_matrix_list[0])
        p2 = MaxPooling('max_pool2', c3, 2, padding='valid')
        ####
        d2 = g_dense_blk('dense2', p2, [32, 8], [5, 7], 2, nr_orients, filter_type,
                         dense_basis_list, dense_rot_list, bn_init=False)
        c4 = GConv2D('ds_conv4', d2, 32, 5, nr_orients, filter_type,
                     basis_filter_list[0], rot_matrix_list[0])
        p3 = MaxPooling('max_pool3', c4, 2, padding='valid')
        ####
        d3 = g_dense_blk('dense3', p3, [32, 8], [5, 7], 3, nr_orients, filter_type,
                         dense_basis_list, dense_rot_list, bn_init=False)
        c5 = GConv2D('ds_conv5', d3, 32, 5, nr_orients, filter_type,
                     basis_filter_list[0], rot_matrix_list[0])
        p4 = MaxPooling('max_pool4', c5, 2, padding='valid')
        ####
        d4 = g_dense_blk('dense4', p4, [32, 8], [5, 7], 3, nr_orients, filter_type,
                         dense_basis_list, dense_rot_list, bn_init=False)
        c6 = GConv2D('ds_conv6', d4, 32, 5, nr_orients, filter_type,
                     basis_filter_list[0], rot_matrix_list[0])
        p5 = AvgPooling('glb_avg_pool', c6, 6, padding='valid')
        p6 = GroupPool('orient_pool', p5, nr_orients, pool_type='max')
        ####
        c7 = Conv2D('conv3', p6, 96, 1, use_bias=True, nl=BNReLU)
        c7 = tf.layers.dropout(c7, rate=0.3, seed=5, training=is_training)
        c8 = Conv2D('conv4', c7, 96, 1, use_bias=True, nl=BNReLU)
        c8 = tf.layers.dropout(c8, rate=0.3, seed=5, training=is_training)
    return c8
def LinearBottleneck(x, ich, och, kernel,
                     padding='SAME', stride=1, activation=BNPReLU,
                     t=3, w_init=None):
    '''
    mobilenetv2 linear bottleneck.
    '''
    out = Conv2D(_get_conv_name(), x, int(ich * t), 1, activation=BNPReLU)
    out = DWConv(_get_dwconv_name(), out, kernel, padding, stride, w_init, activation=activation)
    out = Conv2D(_get_conv_name(), out, och, 1, activation=BNonly)

    if stride != 1:
        return out
    if ich != och:
        x = Conv2D(_get_conv_name(), x, int(och), 1, activation=BNonly)
    return x + out
def resnet_bottleneck(x, ch_out, stride, stride_first=False):
    """
    stride_first: original resnet put stride on first conv. fb.resnet.torch put stride on second conv.  # noqa
    """
    shortcut = x
    x = Conv2D('conv1', x, ch_out, 1, stride=stride if stride_first else 1, nl=BNReLU)
    x = Conv2D('conv2', x, ch_out, 3, stride=1 if stride_first else stride, nl=BNReLU)
    x = Conv2D('conv3', x, ch_out * 4, 1, nl=get_bn(zero_init=True))
    return x + resnet_shortcut(shortcut, ch_out * 4, stride, nl=get_bn(zero_init=False))
def resnet_bottleneck(l, ch_out, stride):
    shortcut = l
    l = Conv2D('conv1', l, ch_out, 1, activation=BNReLU)
    if stride == 2:
        l = tf.pad(l, [[0, 0], [0, 0], maybe_reverse_pad(0, 1), maybe_reverse_pad(0, 1)])
        l = Conv2D('conv2', l, ch_out, 3, strides=2, activation=BNReLU, padding='VALID')
    else:
        l = Conv2D('conv2', l, ch_out, 3, strides=stride, activation=BNReLU)
    l = Conv2D('conv3', l, ch_out * 4, 1, activation=get_bn(zero_init=True))
    ret = l + resnet_shortcut(shortcut, ch_out * 4, stride, activation=get_bn(zero_init=False))
    return tf.nn.relu(ret, name='output')
def resnet_bottleneck(l, ch_out, stride, stride_first=False):
    """
    stride_first: original resnet put stride on first conv. fb.resnet.torch put stride on second conv.
    """
    shortcut = l
    l = Conv2D('conv1', l, ch_out, 1, strides=stride if stride_first else 1, activation=BNReLU)
    l = Grconv('conv2', l, ch_out, 3, strides=1 if stride_first else stride, activation=BNReLU)
    l = Conv2D('conv3', l, ch_out * 4, 1, activation=get_bn(zero_init=True))
    return l + resnet_shortcut(shortcut, ch_out * 4, stride, activation=get_bn(zero_init=False))
def conv_with_rn(gradient):
    out = Conv2D('conv', gradient, gradient.get_shape()[3], 1, strides=1,
                 activation=get_rn(),
                 kernel_initializer=tf.contrib.layers.variance_scaling_initializer(2.0))
    gradient = gradient + out
    return gradient
def resnet_shortcut(l, n_out, stride, activation=tf.identity):
    n_in = l.get_shape().as_list()[1]
    if n_in != n_out:
        return Conv2D('convshortcut', l, n_out, 1, strides=stride, activation=activation)
    else:
        return l
def resnet_shortcut(l, n_out, stride, activation=tf.identity):
    n_in = l.get_shape().as_list()[1 if is_data_format_nchw() else 3]
    if n_in != n_out:
        return Conv2D('convshortcut', l, n_out, 1, strides=stride, activation=activation)
    else:
        return l
def resnet_shortcut(l, n_out, stride, activation=tf.identity):
    n_in = l.get_shape().as_list()[1]
    if n_in != n_out:   # change dimension when channel is not the same
        return Conv2D('convshortcut', l, n_out, 1, strides=stride, activation=activation)
    else:
        return l
def denoising(name, l, embed=True, softmax=True):
    with tf.variable_scope(name):
        f = non_local_op(l, embed=embed, softmax=softmax)
        f = Conv2D('conv', f, l.shape[1], 1, strides=1, activation=get_bn(zero_init=True))
        l = l + f
    return l
def resnet_shortcut(l, n_out, stride, activation=tf.identity):
    n_in = l.shape[1]
    if n_in != n_out:   # change dimension when channel is not the same
        # TF's SAME mode output ceil(x/stride), which is NOT what we want when x is odd and stride is 2
        # In FPN mode, the images are pre-padded already.
        if not cfg.MODE_FPN and stride == 2:
            l = l[:, :, :-1, :-1]
        return Conv2D('convshortcut', l, n_out, 1, strides=stride, activation=activation)
    else:
        return l
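# Quick arithmetic check (illustrative only) of the crop above: with an odd
# input size x and stride 2, a SAME-padded 1x1 conv would give ceil(x/2),
# while the bottleneck's main path (pad by (0, 1), then 3x3 VALID stride 2,
# as in the resnet_bottleneck variants in this file) gives one less.
# Cropping the shortcut first makes the two sizes agree.
import math

for x in (7, 15, 31):                       # odd spatial sizes
    same_1x1 = math.ceil(x / 2)             # shortcut without the crop
    cropped = math.ceil((x - 1) / 2)        # shortcut after l[:, :, :-1, :-1]
    main = (x + 1 - 3) // 2 + 1             # pad (0, 1) -> 3x3 VALID, stride 2
    assert cropped == main and same_1x1 == main + 1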
def res_blk(name, l, ch, ksize, count, split=1, strides=1, freeze=False):
    ##########################################
    # ResNet50 block
    ##########################################
    ch_in = l.get_shape().as_list()
    with tf.variable_scope(name):
        for i in range(0, count):
            with tf.variable_scope('block' + str(i)):
                x = l if i == 0 else BNReLU('preact', l)
                x = Conv2D('conv1', x, ch[0], ksize[0], activation=BNReLU)
                x = Conv2D('conv2', x, ch[1], ksize[1], split=split,
                           strides=strides if i == 0 else 1, activation=BNReLU)
                x = Conv2D('conv3', x, ch[2], ksize[2], activation=tf.identity)
                if (strides != 1 or ch_in[1] != ch[2]) and i == 0:
                    l = Conv2D('convshortcut', l, ch[2], 1, strides=strides)
                x = tf.stop_gradient(x) if freeze else x
                l = l + x
        # end of each group need an extra activation
        l = BNReLU('bnlast', l)
    return l
def maskrcnn_upXconv_head(feature, num_class, num_convs):
    """
    Args:
        feature (N x C x s x s): size is 7 in C4 models and 14 in FPN models.
        num_class (int): num_category + 1
        num_convs (int): number of convolution layers

    Returns:
        mask_logits (N x num_category x 2s x 2s):
    """
    l = feature
    with argscope([Conv2D, Conv2DTranspose], data_format='channels_first',
                  kernel_initializer=tf.variance_scaling_initializer(
                      scale=2.0, mode='fan_out', distribution='normal')):
        # c2's MSRAFill is fan_out
        for k in range(num_convs):
            l = Conv2D('fcn{}'.format(k), l, config.MASKRCNN_HEAD_DIM, 3, activation=tf.nn.relu)
        l = Conv2DTranspose('deconv', l, config.MASKRCNN_HEAD_DIM, 2, strides=2, activation=tf.nn.relu)
        l = Conv2D('conv', l, num_class - 1, 1)
    return l
def non_local_op(l, embed, softmax):
    """
    Feature Denoising, Sec 4.2 & Fig 5.
    Args:
        embed (bool): whether to use embedding on theta & phi
        softmax (bool): whether to use gaussian (softmax) version or the dot-product version.
    """
    n_in, H, W = l.shape.as_list()[1:]
    if embed:
        # use integer division so the filter count stays an int under Python 3
        theta = Conv2D('embedding_theta', l, n_in // 2, 1, strides=1,
                       kernel_initializer=tf.random_normal_initializer(stddev=0.01))
        phi = Conv2D('embedding_phi', l, n_in // 2, 1, strides=1,
                     kernel_initializer=tf.random_normal_initializer(stddev=0.01))
        g = l
    else:
        theta, phi, g = l, l, l
    if n_in > H * W or softmax:
        f = tf.einsum('niab,nicd->nabcd', theta, phi)
        if softmax:
            orig_shape = tf.shape(f)
            f = tf.reshape(f, [-1, H * W, H * W])
            f = f / tf.sqrt(tf.cast(theta.shape[1], theta.dtype))
            f = tf.nn.softmax(f)
            f = tf.reshape(f, orig_shape)
        f = tf.einsum('nabcd,nicd->niab', f, g)
    else:
        f = tf.einsum('nihw,njhw->nij', phi, g)
        f = tf.einsum('nij,nihw->njhw', f, theta)
    if not softmax:
        f = f / tf.cast(H * W, f.dtype)
    return tf.reshape(f, tf.shape(l))
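# A NumPy sketch (toy shapes, plain NCHW arrays) showing that the two einsums
# in the softmax branch above amount to an (HW x HW) affinity matrix applied
# to g, which is the standard non-local / feature-denoising formulation.
import numpy as np

N, C, H, W = 2, 4, 3, 3
theta, phi, g = (np.random.rand(N, C, H, W) for _ in range(3))

f = np.einsum('niab,nicd->nabcd', theta, phi)      # N x H x W x H x W affinity
out = np.einsum('nabcd,nicd->niab', f, g)          # weighted sum of g -> N x C x H x W

# Same computation with explicitly flattened spatial dimensions:
t, p, gg = (x.reshape(N, C, H * W) for x in (theta, phi, g))
affinity = np.einsum('nip,niq->npq', t, p)         # N x HW x HW
out2 = np.einsum('npq,niq->nip', affinity, gg).reshape(N, C, H, W)
assert np.allclose(out, out2)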
def fpn_model(features):
    """
    Args:
        features ([tf.Tensor]): ResNet features c2-c5
    Returns:
        [tf.Tensor]: FPN features p2-p6
    """
    assert len(features) == 4, features
    num_channel = config.FPN_NUM_CHANNEL

    def upsample2x(name, x):
        return FixedUnPooling(
            name, x, 2, unpool_mat=np.ones((2, 2), dtype='float32'),
            data_format='channels_first')

        # tf.image.resize is, again, not aligned.
        # with tf.name_scope(name):
        #     logger.info("Nearest neighbor")
        #     shape2d = tf.shape(x)[2:]
        #     x = tf.transpose(x, [0, 2, 3, 1])
        #     x = tf.image.resize_nearest_neighbor(x, shape2d * 2, align_corners=True)
        #     x = tf.transpose(x, [0, 3, 1, 2])
        #     return x

    with argscope(Conv2D, data_format='channels_first',
                  nl=tf.identity, use_bias=True,
                  kernel_initializer=tf.variance_scaling_initializer(scale=1.)):
        lat_2345 = [Conv2D('lateral_1x1_c{}'.format(i + 2), c, num_channel, 1)
                    for i, c in enumerate(features)]
        lat_sum_5432 = []
        for idx, lat in enumerate(lat_2345[::-1]):
            if idx == 0:
                lat_sum_5432.append(lat)
            else:
                lat = lat + upsample2x('upsample_lat{}'.format(6 - idx), lat_sum_5432[-1])
                lat_sum_5432.append(lat)
        p2345 = [Conv2D('posthoc_3x3_p{}'.format(i + 2), c, num_channel, 3)
                 for i, c in enumerate(lat_sum_5432[::-1])]
        p6 = MaxPooling('maxpool_p6', p2345[-1], pool_size=1, strides=2, data_format='channels_first')
        return p2345 + [p6]
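# A quick NumPy check (illustrative) that 2x FixedUnPooling with a ones((2, 2))
# unpool matrix, as used by upsample2x above, is per-channel nearest-neighbor
# upsampling (a Kronecker product with a block of ones).
import numpy as np

x = np.arange(6, dtype='float32').reshape(2, 3)      # one channel of a feature map
up = np.kron(x, np.ones((2, 2), dtype='float32'))    # each value becomes a 2x2 block
nn = x.repeat(2, axis=0).repeat(2, axis=1)           # nearest-neighbor 2x upsample
assert np.allclose(up, nn)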
def fpn_model(features):
    """
    Args:
        features ([tf.Tensor]): ResNet features c2-c5
    Returns:
        [tf.Tensor]: FPN features p2-p6
    """
    assert len(features) == 4, features
    num_channel = 256
    use_gn = config.NORM == 'GN'

    def upsample2x(name, x):
        return FixedUnPooling(
            name, x, 2, unpool_mat=np.ones((2, 2), dtype='float32'),
            data_format='channels_first')

    with argscope(Conv2D, data_format='channels_first',
                  activation=tf.identity, use_bias=True,
                  kernel_initializer=tf.variance_scaling_initializer(scale=1.)):
        lat_2345 = [Conv2D('lateral_1x1_c{}'.format(i + 2), c, num_channel, 1)
                    for i, c in enumerate(features)]
        if use_gn:
            lat_2345 = [GroupNorm('gn_c{}'.format(i + 2), c) for i, c in enumerate(lat_2345)]
        lat_sum_5432 = []
        for idx, lat in enumerate(lat_2345[::-1]):
            if idx == 0:
                lat_sum_5432.append(lat)
            else:
                lat = lat + tf.transpose(
                    tf.image.resize_nearest_neighbor(
                        tf.transpose(lat_sum_5432[-1], [0, 2, 3, 1]),
                        size=tf.shape(lat)[-2:]),
                    [0, 3, 1, 2])
                # lat = lat + upsample2x('upsample_lat{}'.format(6 - idx), lat_sum_5432[-1])
                lat_sum_5432.append(lat)
        p2345 = [Conv2D('posthoc_3x3_p{}'.format(i + 2), c, num_channel, 3)
                 for i, c in enumerate(lat_sum_5432[::-1])]
        p6 = tf.pad(p2345[-1], [[0, 0], [0, 0], maybe_reverse_pad(0, 1), maybe_reverse_pad(0, 1)])
        p6 = MaxPooling('maxpool_p6', p6, pool_size=3, strides=2,
                        data_format='channels_first', padding='VALID')
        # p1 = tf.transpose(tf.image.resize_nearest_neighbor(
        #     tf.transpose(p2345[0], [0, 2, 3, 1]), size=tf.shape(p2345[0])[-2:] * 2), [0, 3, 1, 2])
        all_p = p2345 + [p6]
    return all_p[::-1]
def LinearBottleneck(x, ich, och, kernel,
                     padding='SAME', stride=1, active=None,
                     t=3, use_ab=False, w_init=None):
    '''
    mobilenetv2 linear bottleneck.
    '''
    if active is None:
        active = True if kernel > 3 else False
    out = Conv2D('conv_e', x, int(ich * t), 1, activation=BNReLU)
    if use_ab:
        out = AccuracyBoost('ab', out)
    out = DWConv('conv_d', out, kernel, padding, stride, w_init, active)
    out = Conv2D('conv_p', out, och, 1, activation=None)
    with tf.variable_scope('conv_p'):
        out = BatchNorm('bn', out)
    return out
def rpn_head(featuremap, channel, num_anchors):
    """
    Returns:
        label_logits: fHxfWxNA
        box_logits: fHxfWxNAx4
    """
    with argscope(Conv2D, data_format='channels_first',
                  kernel_initializer=tf.random_normal_initializer(stddev=0.01)):
        hidden = Conv2D('conv0', featuremap, channel, 3, activation=tf.nn.relu)

        label_logits = Conv2D('class', hidden, num_anchors, 1)
        box_logits = Conv2D('box', hidden, 4 * num_anchors, 1)
        # 1, NA(*4), im/16, im/16 (NCHW)

        label_logits = tf.transpose(label_logits, [0, 2, 3, 1])  # 1xfHxfWxNA
        label_logits = tf.squeeze(label_logits, 0)  # fHxfWxNA

        shp = tf.shape(box_logits)  # 1x(NAx4)xfHxfW
        box_logits = tf.transpose(box_logits, [0, 2, 3, 1])  # 1xfHxfWx(NAx4)
        box_logits = tf.reshape(box_logits, tf.stack([shp[2], shp[3], num_anchors, 4]))  # fHxfWxNAx4
    return label_logits, box_logits
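# A minimal NumPy sketch (toy sizes, not model code) checking the layout
# produced by the transpose + reshape above: the NCHW box output of shape
# 1 x (NA*4) x fH x fW becomes fH x fW x NA x 4, with the 4 box parameters
# of anchor a coming from channels a*4 .. a*4+3.
import numpy as np

NA, fH, fW = 3, 2, 2
box = np.arange(NA * 4 * fH * fW).reshape(1, NA * 4, fH, fW)   # NCHW, batch of 1
nhwc = box.transpose(0, 2, 3, 1)                               # 1 x fH x fW x (NA*4)
out = nhwc.reshape(fH, fW, NA, 4)                              # fH x fW x NA x 4

# parameter k of anchor a at location (y, x) == NCHW channel a*4 + k
assert out[1, 0, 2, 3] == box[0, 2 * 4 + 3, 1, 0]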
def DownsampleBottleneck(x, ich, och, kernel,
                         padding='SAME', stride=2, activation=None,
                         t=3, use_ab=False, w_init=None):
    '''
    downsample linear bottleneck.
    '''
    if activation is None:
        activation = BNReLU if kernel > 3 else BNOnly
    out_e = Conv2D('conv_e', x, ich * t, 1, activation=BNReLU)
    if use_ab:
        out_e = AccuracyBoost('ab', out_e)
    out_d = DWConv('conv_d', out_e, kernel, padding, stride, w_init, activation)
    out_m = DWConv('conv_m', out_e, kernel, padding, stride, w_init, activation)
    out = tf.concat([out_d, out_m], axis=-1)
    out = Conv2D('conv_p', out, och, 1, activation=BNOnly)
    return out
def resnet_shortcut(l, n_out, stride, nl=tf.identity):
    # data_format = get_arg_scope()['Conv2D']['data_format']
    n_in = l.get_shape().as_list()[3]
    if n_in != n_out:   # change dimension when channel is not the same
        return Conv2D('convshortcut', l, n_out, 1, stride=stride, nl=nl)
    else:
        return l
def resnet_bottleneck(l, ch_out, stride, group=1, res2_bottleneck=64):
    """
    Args:
        group (int): the number of groups for resnext
        res2_bottleneck (int): the number of channels in res2 bottleneck.
            The default corresponds to ResNeXt 1x64d, i.e. vanilla ResNet.
    """
    ch_factor = res2_bottleneck * group // 64
    shortcut = l
    l = Conv2D('conv1', l, ch_out * ch_factor, 1, strides=1, activation=BNReLU)
    l = Conv2D('conv2', l, ch_out * ch_factor, 3, strides=stride, activation=BNReLU, split=group)
    # ImageNet in 1 Hour, Sec 5.1:
    l = Conv2D('conv3', l, ch_out * 4, 1, activation=get_bn(zero_init=True))
    ret = l + resnet_shortcut(shortcut, ch_out * 4, stride, activation=get_bn(zero_init=False))
    return tf.nn.relu(ret, name='block_output')
def atrous_spatial_pyramid_pooling(logits):
    # Compute the ASPP.
    logits_size = tf.shape(logits)[1:3]
    with argscope(Conv2D, filters=256, kernel_size=3, activation=BNReLU):
        ASPP_1 = Conv2D('aspp_conv1', logits, kernel_size=1)
        ASPP_2 = Conv2D('aspp_conv2', logits, dilation_rate=cfg.atrous_rates[0])
        ASPP_3 = Conv2D('aspp_conv3', logits, dilation_rate=cfg.atrous_rates[1])
        ASPP_4 = Conv2D('aspp_conv4', logits, dilation_rate=cfg.atrous_rates[2])

        # ImagePooling = GlobalAvgPooling('image_pooling', logits)
        ImagePooling = tf.reduce_mean(logits, [1, 2], name='global_average_pooling', keepdims=True)
        image_level_features = Conv2D('image_level_conv', ImagePooling, kernel_size=1)
        image_level_features = tf.image.resize_bilinear(image_level_features, logits_size, name='upsample')

        output = tf.concat([ASPP_1, ASPP_2, ASPP_3, ASPP_4, image_level_features], -1, name='concat')
        output = Conv2D('conv_after_concat', output, 256, 1, activation=BNReLU)
    return output
def resnet_bottleneck(layer, ch_out, stride):
    shortcut = layer
    if cfg.BACKBONE.STRIDE_1X1:
        if stride == 2:
            layer = layer[:, :, :-1, :-1]
        layer = Conv2D('conv1', layer, ch_out, 1, strides=stride)
        layer = Conv2D('conv2', layer, ch_out, 3, strides=1)
    else:
        layer = Conv2D('conv1', layer, ch_out, 1, strides=1)
        if stride == 2:
            layer = tf.pad(layer, [[0, 0], [0, 0], maybe_reverse_pad(0, 1), maybe_reverse_pad(0, 1)])
            layer = Conv2D('conv2', layer, ch_out, 3, strides=2, padding='VALID')
        else:
            layer = Conv2D('conv2', layer, ch_out, 3, strides=stride)
    layer = Conv2D('conv3', layer, ch_out * 4, 1, activation=get_norm(zero_init=True))
    ret = layer + resnet_shortcut(shortcut, ch_out * 4, stride, activation=get_norm(zero_init=False))
    return tf.nn.relu(ret, name='output')
def resnet_shortcut(l, n_out, stride, activation=tf.identity):
    data_format = get_arg_scope()['Conv2D']['data_format']
    n_in = l.get_shape().as_list()[1 if data_format in ['NCHW', 'channels_first'] else 3]
    if n_in != n_out:   # change dimension when channel is not the same
        return Conv2D('convshortcut', l, n_out, 1, strides=stride, activation=activation)
    else:
        return l
def LinearBottleneck(x, ich, och, kernel,
                     padding='SAME', stride=1, activation=None,
                     t=3, use_ab=False, w_init=None):
    '''
    mobilenetv2 linear bottleneck.
    '''
    if activation is None:
        activation = BNReLU if kernel > 3 else BNOnly
    out = Conv2D('conv_e', x, int(ich * t), 1, activation=BNReLU)
    out = DWConv('conv_d', out, kernel, padding, stride, w_init, activation)
    if use_ab and activation == BNReLU:
        out = AccuracyBoost('ab', out)
    out = Conv2D('conv_p', out, och, 1, activation=BNOnly)
    return out
def se_resnet_bottleneck(l, ch_out, stride):
    shortcut = l
    l = Conv2D('conv1', l, ch_out, 1, activation=BNReLU)
    l = Conv2D('conv2', l, ch_out, 3, strides=stride, activation=BNReLU)
    l = Conv2D('conv3', l, ch_out * 4, 1, activation=get_bn(zero_init=True))

    squeeze = GlobalAvgPooling('gap', l)
    squeeze = FullyConnected('fc1', squeeze, ch_out // 4, activation=tf.nn.relu)
    squeeze = FullyConnected('fc2', squeeze, ch_out * 4, activation=tf.nn.sigmoid)

    data_format = get_arg_scope()['Conv2D']['data_format']
    ch_ax = 1 if data_format in ['NCHW', 'channels_first'] else 3
    shape = [-1, 1, 1, 1]
    shape[ch_ax] = ch_out * 4
    l = l * tf.reshape(squeeze, shape)

    return l + resnet_shortcut(shortcut, ch_out * 4, stride, activation=get_bn(zero_init=False))
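# A NumPy sketch (assumed NCHW layout and toy sizes) of the broadcast used
# above: the per-channel sigmoid gate is reshaped to (-1, C, 1, 1) so that it
# scales every spatial position of its channel.
import numpy as np

N, C, H, W = 2, 8, 5, 5
l = np.random.rand(N, C, H, W)
gate = np.random.rand(N, C)               # output of the 'fc2' sigmoid layer

ch_ax = 1                                 # channels_first
shape = [-1, 1, 1, 1]
shape[ch_ax] = C
scaled = l * gate.reshape(shape)
assert np.allclose(scaled[:, 3], l[:, 3] * gate[:, 3, None, None])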