def relu_conv_bn(inp, ker_shape, stride, padding, out_chl, isrelu=True, isbn=True):
    """Build an (optional ReLU) -> Conv2D -> (optional BN + affine) stage.

    Increments the module-level ``idx`` counter once and uses the new value
    in the names of the operators it creates ("conv{idx}", "bn{idx}",
    "bnaff{idx}").

    Args:
        inp: Input variable.
        ker_shape: Passed to ``Conv2D`` as ``kernel_shape``.
        stride: Passed to ``Conv2D`` as ``stride``.
        padding: Passed to ``Conv2D`` as ``padding``.
        out_chl: Passed to ``Conv2D`` as ``output_nr_channel``.
        isrelu: If true, ``arith.ReLU`` is applied to ``inp`` before the
            convolution.
        isbn: If true, the convolution output goes through ``BN`` (eps=1e-9)
            and an ``ElementwiseAffine`` created with ``k=C(1)``, ``b=C(0)``.

    Returns:
        The output of the last operator that was applied.
    """
    global idx
    idx += 1

    pre_act = arith.ReLU(inp) if isrelu else inp

    # The convolution itself is kept linear (Identity nonlinearity).
    conv_out = Conv2D(
        "conv{}".format(idx),
        pre_act,
        kernel_shape=ker_shape,
        stride=stride,
        padding=padding,
        output_nr_channel=out_chl,
        nonlinearity=Identity(),
    )
    if not isbn:
        return conv_out

    normed = BN("bn{}".format(idx), conv_out, eps=1e-9)
    return ElementwiseAffine(
        "bnaff{}".format(idx),
        normed,
        shared_in_channels=False,
        k=C(1),
        b=C(0),
    )
def skip(inp, isdown, chl):
    """Build a two-branch skip path, optionally sub-sampled by stride 2.

    ``isdown == -1`` returns ``inp`` untouched.  Otherwise two branches are
    built from ``inp``; each ends in a 1x1 ``relu_conv_bn`` (ReLU and BN
    disabled) producing ``chl // 2`` channels.  When ``isdown != 0`` each
    branch is first passed through a ``Pooling2D`` with ``window=1``,
    ``stride=2``, ``mode="AVERAGE"``; the second branch pools
    ``inp[:, :, 1:, 1:]``, i.e. the input with its first row and column
    dropped.  The branches are concatenated on axis 1 and followed by
    ``BN`` + ``ElementwiseAffine``.

    Args:
        inp: Input variable (indexed with four axes here).
        isdown: -1 for identity, 0 for no pooling, anything else for the
            stride-2 variant.  Also embedded in the BN/affine names.
        chl: Total output channel budget; each branch gets ``chl // 2``.

    Returns:
        ``inp`` when ``isdown == -1``; otherwise the affine output.
    """
    global idx

    if isdown == -1:
        return inp

    downsample = isdown != 0

    # Branch 1: aligned with the top-left corner of the input.
    first = inp
    if downsample:
        first = Pooling2D(
            "pooling1_{}".format(idx),
            inp,
            window=1,
            stride=2,
            mode="AVERAGE",
        )
    first = relu_conv_bn(first, 1, 1, 0, chl // 2, isrelu=False, isbn=False)

    # Branch 2: same pipeline on the input shifted by one pixel in both
    # spatial axes.  ``idx`` is re-read here, after the call above.
    second = inp
    if downsample:
        second = Pooling2D(
            "pooling2_{}".format(idx),
            inp[:, :, 1:, 1:],
            window=1,
            stride=2,
            mode="AVERAGE",
        )
    second = relu_conv_bn(second, 1, 1, 0, chl // 2, isrelu=False, isbn=False)

    merged = O.Concat([first, second], axis=1)
    merged = BN("bn_down_{}".format(isdown), merged, eps=1e-9)
    return ElementwiseAffine(
        "bnaff_down_{}".format(isdown),
        merged,
        shared_in_channels=False,
        k=C(1),
        b=C(0),
    )
def bn_relu_conv(inp, ker_shape, stride, padding, out_chl, has_relu, has_bn, has_conv=True):
    """Build a pre-activation stage: (BN + affine) -> (ReLU) -> (Conv2D).

    Every sub-stage is optional.  The module-level ``idx`` counter is
    incremented once per call and used in the operator names.

    Args:
        inp: Input variable.
        ker_shape, stride, padding, out_chl: Forwarded to ``Conv2D`` as
            ``kernel_shape``, ``stride``, ``padding`` and
            ``output_nr_channel``.
        has_relu: Apply ``arith.ReLU`` after the (optional) normalisation.
        has_bn: Apply ``BN`` (eps=1e-9) and ``ElementwiseAffine`` first.
        has_conv: When false, return right after the activation step.

    Returns:
        The convolution output, or the activation output when
        ``has_conv`` is false.
    """
    global idx
    idx += 1

    out = inp
    if has_bn:
        out = BN("bn{}".format(idx), out, eps=1e-9)
        out = ElementwiseAffine(
            "bnaff{}".format(idx),
            out,
            shared_in_channels=False,
            k=C(1),
            b=C(0),
        )

    if has_relu:
        out = arith.ReLU(out)

    if not has_conv:
        return out

    # No explicit W/b initialisers are passed; Conv2D's own defaults apply.
    return Conv2D(
        "conv{}".format(idx),
        out,
        kernel_shape=ker_shape,
        stride=stride,
        padding=padding,
        output_nr_channel=out_chl,
        nonlinearity=Identity(),
    )
def conv_bn(inp, ker_shape, stride, padding, out_chl, isrelu, mode=None):
    """Build Conv2D -> BN -> affine -> (ReLU) with explicit initialisers.

    For ``ker_shape == 1`` the weight comes from a random orthogonal matrix
    (``ortho_group.rvs(out_chl)``) truncated to the input channel count and
    the bias is all zeros, both wrapped in ``ConstProvider``.  For any other
    kernel size the weight initialiser is ``G(mean=0, std=...)`` with
    ``std = sqrt((1 + int(isrelu)) / (ker_shape**2 * in_channels))`` and the
    bias initialiser is ``C(0)``.

    The module-level ``idx`` counter is incremented once and the input
    shape, kernel size and output channels are printed.

    Args:
        inp: Input variable; ``inp.partial_shape[1]`` is read as the input
            channel count.
        ker_shape, stride, padding, out_chl: Forwarded to ``Conv2D``.
        isrelu: Apply ``arith.ReLU`` at the end; also scales the Gaussian
            std as described above.
        mode: Forwarded to ``Conv2D`` as ``group``.

    Returns:
        ``(activated, conv_out)``: the final output and the raw convolution
        output.
    """
    global idx
    idx += 1

    print(inp.partial_shape, ker_shape, out_chl)

    if ker_shape == 1:
        # NOTE(review): the slice below yields at most ``out_chl`` columns,
        # so this presumably assumes in_channels <= out_chl -- confirm.
        ortho = ortho_group.rvs(out_chl)
        ortho = ortho[:, :inp.partial_shape[1]]
        ortho = ortho.reshape(ortho.shape[0], ortho.shape[1], 1, 1)
        weight_init = ConstProvider(ortho)
        bias_init = ConstProvider(np.zeros(out_chl))
    else:
        fan_in = ker_shape**2 * inp.partial_shape[1]
        weight_init = G(mean=0, std=((1 + int(isrelu)) / fan_in)**0.5)
        bias_init = C(0)

    conv_out = Conv2D(
        "conv{}".format(idx),
        inp,
        kernel_shape=ker_shape,
        stride=stride,
        padding=padding,
        output_nr_channel=out_chl,
        group=mode,
        W=weight_init,
        b=bias_init,
        nonlinearity=Identity(),
    )

    activated = BN("bn{}".format(idx), conv_out, eps=1e-9)
    activated = ElementwiseAffine(
        "bnaff{}".format(idx),
        activated,
        shared_in_channels=False,
        k=C(1),
        b=C(0),
    )
    if isrelu:
        activated = arith.ReLU(activated)

    return activated, conv_out
def bn_relu_conv(inp, ker_shape, stride, padding, out_chl, has_relu, has_bn, has_conv=True):
    """Build (BN + affine) -> (ReLU) -> (Conv2D) and expose the conv weight.

    The module-level ``idx`` counter is incremented once per call and used
    in the operator names.

    Args:
        inp: Input variable.
        ker_shape, stride, padding, out_chl: Forwarded to ``Conv2D``.
        has_relu: Apply ``arith.ReLU`` after the (optional) normalisation.
        has_bn: Apply ``BN`` (eps=1e-9) and ``ElementwiseAffine`` first.
        has_conv: When false, skip the convolution.

    Returns:
        ``(output, weight)``.  ``weight`` is ``conv.inputs[1]`` of the
        created convolution, or ``None`` when ``has_conv`` is false.

    Raises:
        AssertionError: if the name of ``conv.inputs[1]`` does not contain
            ``":W"``.
    """
    global idx
    idx += 1

    out = inp
    if has_bn:
        out = BN("bn{}".format(idx), out, eps=1e-9)
        out = ElementwiseAffine(
            "bnaff{}".format(idx),
            out,
            shared_in_channels=False,
            k=C(1),
            b=C(0),
        )

    if has_relu:
        out = arith.ReLU(out)

    if not has_conv:
        return out, None

    conv = Conv2D(
        "conv{}".format(idx),
        out,
        kernel_shape=ker_shape,
        stride=stride,
        padding=padding,
        output_nr_channel=out_chl,
        nonlinearity=Identity(),
    )

    # Sanity check that the second input really is the weight parameter.
    weight = conv.inputs[1]
    assert ":W" in weight.name
    return conv, weight
def deconv_bn_relu(name, inp, kernel_shape=None, stride=None, padding=None, output_nr_channel=None, isbnrelu=True):
    """Build ``O.Deconv2DVanilla`` optionally followed by BN, affine and ReLU.

    Args:
        name: Name of the deconvolution operator; the BN and affine
            operators are named ``name + "bn"`` and ``name + "bnaff"``.
        inp: Input variable.
        kernel_shape, stride, padding, output_nr_channel: Forwarded
            unchanged to ``O.Deconv2DVanilla``.
        isbnrelu: When true, append ``BN`` (eps=1e-9),
            ``ElementwiseAffine`` and ``arith.ReLU``.

    Returns:
        The output of the last operator that was applied.
    """
    out = O.Deconv2DVanilla(
        name,
        inp,
        kernel_shape=kernel_shape,
        stride=stride,
        padding=padding,
        output_nr_channel=output_nr_channel,
    )
    if not isbnrelu:
        return out

    out = BN(name + "bn", out, eps=1e-9)
    out = ElementwiseAffine(
        name + "bnaff",
        out,
        shared_in_channels=False,
        k=C(1),
        b=C(0),
    )
    return arith.ReLU(out)
def conv_bn(inp, ker_shape, stride, padding, out_chl, isrelu):
    """Build Conv2D -> BN -> affine -> (ReLU).

    The module-level ``idx`` counter is incremented once per call and used
    in the operator names ("conv{idx}", "bn{idx}", "bnaff{idx}").

    Args:
        inp: Input variable.
        ker_shape, stride, padding, out_chl: Forwarded to ``Conv2D`` as
            ``kernel_shape``, ``stride``, ``padding`` and
            ``output_nr_channel``.
        isrelu: Apply ``arith.ReLU`` after the affine step.

    Returns:
        The output of the last operator that was applied.
    """
    global idx
    idx += 1

    conv_out = Conv2D(
        "conv{}".format(idx),
        inp,
        kernel_shape=ker_shape,
        stride=stride,
        padding=padding,
        output_nr_channel=out_chl,
        nonlinearity=Identity(),
    )
    normed = BN("bn{}".format(idx), conv_out, eps=1e-9)
    scaled = ElementwiseAffine(
        "bnaff{}".format(idx),
        normed,
        shared_in_channels=False,
        k=C(1),
        b=C(0),
    )
    return arith.ReLU(scaled) if isrelu else scaled
def conv_bn(inp, ker_shape, stride, padding, out_chl, isrelu, group=1, shift=0):
    """Build a (grouped / channel-shifted) convolution -> BN -> affine -> (ReLU).

    Three variants are selected by ``group`` and ``shift``:

    * ``group == 1``: one plain ``Conv2D`` (no ``group`` argument passed).
    * ``group != 1`` and ``shift == 0``: one ``Conv2D`` with ``group=group``.
    * ``group != 1`` and ``shift != 0``: a cascade of stages with step
      sizes 1, 2, 4, ... while the step is below ``group``.  Each stage
      applies one grouped convolution to the current tensor, rotates the
      tensor's channels by ``step * channels // group`` positions, applies
      a second grouped convolution to the rotated tensor, and sums the two
      results.  ``shift`` only acts as an on/off flag; its value is not
      used otherwise.

    The module-level ``idx`` counter is incremented once per call and used
    in the operator names.

    Args:
        inp: Input variable; ``partial_shape[1]`` is read as channel count
            in the shifted variant.
        ker_shape, stride, padding, out_chl: Forwarded to every ``Conv2D``.
        isrelu: Apply ``arith.ReLU`` after the affine step.
        group: Group count for the convolution(s).
        shift: Non-zero enables the shifted cascade.

    Returns:
        The output of the last operator that was applied.

    Raises:
        ValueError: if the shifted variant is requested with ``group < 1``.
    """
    global idx
    idx += 1

    def _conv(name, src, **extra):
        # Linear convolution shared by all variants; ``extra`` carries
        # ``group`` for the grouped ones.
        return Conv2D(
            name,
            src,
            kernel_shape=ker_shape,
            stride=stride,
            padding=padding,
            output_nr_channel=out_chl,
            nonlinearity=Identity(),
            **extra
        )

    if group == 1:
        l1 = _conv("conv{}".format(idx), inp)
    elif shift == 0:
        l1 = _conv("conv{}".format(idx), inp, group=group)
    else:
        if group < 1:
            raise ValueError(
                "group must be a positive integer, got {!r}".format(group))

        l1 = inp
        step = 1
        # ``step`` doubles each round, so comparing with ``!=`` would never
        # terminate unless ``group`` is an exact power of two.  ``<`` gives
        # the same stages for powers of two and stops for everything else.
        while step < group:
            straight = _conv(
                "conv{}_{}_1".format(idx, step), l1, group=group)

            # Rotate the channel axis so that the second convolution sees
            # each group paired with channels from a different group.
            inp_chl = l1.partial_shape[1]
            cut = step * inp_chl // group
            l1 = O.Concat([l1[:, cut:, :, :], l1[:, :cut, :, :]], axis=1)

            rotated = _conv(
                "conv{}_{}_2".format(idx, step), l1, group=group)

            l1 = straight + rotated
            step *= 2

    l2 = BN("bn{}".format(idx), l1, eps=1e-9)
    l2 = ElementwiseAffine(
        "bnaff{}".format(idx),
        l2,
        shared_in_channels=False,
        k=C(1),
        b=C(0),
    )
    if isrelu:
        l2 = arith.ReLU(l2)
    return l2
def conv_bn(inp, ker_shape, stride, padding, out_chl, isrelu):
    """Build an "encoder" Conv2D -> BN -> affine -> (ReLU) stage.

    The convolution weight initialiser is ``G(mean=0, std=...)`` with
    ``std = sqrt((1 + int(isrelu)) / (ker_shape**2 * in_channels))``.
    The module-level ``idx`` counter is incremented once per call.  The
    convolution and BN names carry an ``encoder_`` prefix; the affine name
    ("bnaff{idx}") does not.

    Args:
        inp: Input variable; ``inp.partial_shape[1]`` is read as the input
            channel count.
        ker_shape, stride, padding, out_chl: Forwarded to ``Conv2D``.
        isrelu: Apply ``arith.ReLU`` at the end; also scales the init std.

    Returns:
        ``(activated, conv_out)``: the final output and the raw convolution
        output.
    """
    global idx
    idx += 1

    fan_in = ker_shape**2 * inp.partial_shape[1]
    init_std = ((1 + int(isrelu)) / fan_in)**0.5

    conv_out = Conv2D(
        "encoder_conv{}".format(idx),
        inp,
        kernel_shape=ker_shape,
        stride=stride,
        padding=padding,
        output_nr_channel=out_chl,
        W=G(mean=0, std=init_std),
        nonlinearity=Identity(),
    )

    activated = BN("encoder_bn{}".format(idx), conv_out, eps=1e-9)
    activated = ElementwiseAffine(
        "bnaff{}".format(idx),
        activated,
        shared_in_channels=False,
        k=C(1),
        b=C(0),
    )
    if isrelu:
        activated = arith.ReLU(activated)

    return activated, conv_out
def res_layer(inp, chl, stride=1, proj=False):
    """Build a three-convolution residual unit with a scaled residual branch.

    The residual branch is 1x1 (``stride``) -> 3x3 -> 1x1 via ``conv_bn``,
    with ``chl // 4`` channels in the first two convolutions and ``chl`` in
    the last, which has no ReLU.  Its output is passed through an
    ``ElementwiseAffine`` created with ``k=C(0.5)``, ``b=C(0)`` and named
    ``"aff(<branch output name>)"``.  The shortcut is either ``inp`` itself
    or, when ``proj`` is true, a 1x1 ``conv_bn`` with the same ``stride``.
    The sum of both goes through ``arith.ReLU``.

    NOTE(review): this uses each ``conv_bn`` result directly as a variable,
    so it presumably expects a ``conv_bn`` that returns a single output
    rather than a tuple -- confirm which definition is in scope.

    Args:
        inp: Input variable.
        chl: Output channel count of the unit.
        stride: Stride of the first 1x1 convolution and of the projection.
        proj: Use a projection shortcut instead of the identity.

    Returns:
        The ReLU output of the summed branches.
    """
    shortcut = inp

    branch = conv_bn(inp, 1, stride, 0, chl // 4, True)
    branch = conv_bn(branch, 3, 1, 1, chl // 4, True)
    branch = conv_bn(branch, 1, 1, 0, chl, False)

    branch = ElementwiseAffine(
        "aff({})".format(branch.name),
        branch,
        shared_in_channels=False,
        k=C(0.5),
        b=C(0),
    )

    if proj:
        shortcut = conv_bn(shortcut, 1, stride, 0, chl, False)

    return arith.ReLU(branch + shortcut)
def bn_relu_conv(inp, ker_shape, stride, padding, out_chl, isrelu, isbn):
    """Build a pre-activation stage: (BN + affine) -> (ReLU) -> Conv2D.

    The convolution is always applied.  The module-level ``idx`` counter is
    incremented once per call and used in the operator names.

    Args:
        inp: Input variable.
        ker_shape, stride, padding, out_chl: Forwarded to ``Conv2D`` as
            ``kernel_shape``, ``stride``, ``padding`` and
            ``output_nr_channel``.
        isrelu: Apply ``arith.ReLU`` before the convolution.
        isbn: Apply ``BN`` (eps=1e-9) and ``ElementwiseAffine`` first.

    Returns:
        The convolution output.
    """
    global idx
    idx += 1

    out = inp
    if isbn:
        out = BN("bn{}".format(idx), out, eps=1e-9)
        out = ElementwiseAffine(
            "bnaff{}".format(idx),
            out,
            shared_in_channels=False,
            k=C(1),
            b=C(0),
        )
    if isrelu:
        out = arith.ReLU(out)

    # No explicit W/b initialisers are passed; Conv2D's own defaults apply.
    return Conv2D(
        "conv{}".format(idx),
        out,
        kernel_shape=ker_shape,
        stride=stride,
        padding=padding,
        output_nr_channel=out_chl,
        nonlinearity=Identity(),
    )
def create_bn_relu_spatialconv(prefix, f_in, ksize, stride, pad, num_outputs, has_bn=True, has_relu=True, conv_name_fun=None, bn_name_fun=None):
    """Build a spatial conv + 1x1 conv pair, then optional BN/affine and ReLU.

    The first operator is a ``Conv2DVanilla`` with ``group='chan'`` whose
    output channel count equals ``f_in.partial_shape[1]`` (presumably a
    per-channel spatial convolution -- confirm against the framework).  The
    second is a 1x1 ``Conv2D`` with an identity nonlinearity producing
    ``num_outputs`` channels.

    Args:
        prefix: Base name for the operators.
        f_in: Input variable.
        ksize, stride, pad: Kernel shape, stride and padding of the spatial
            convolution.
        num_outputs: Output channels of the 1x1 convolution.
        has_bn: Append ``BatchNormalization`` (eps=1e-9) and an
            ``ElementwiseAffine`` named ``<bn name> + "_scaleshift"``.
        has_relu: Append ``ReLU``.
        conv_name_fun: Optional callable mapping ``prefix`` to the 1x1 conv
            name (default: ``prefix``).  The spatial conv is named with an
            extra ``"_s"`` suffix.
        bn_name_fun: Optional callable mapping ``prefix`` to the BN name
            (default: ``"bn_" + prefix``).

    Returns:
        The output of the last operator that was applied.
    """
    conv_name = conv_name_fun(prefix) if conv_name_fun else prefix

    out = Conv2DVanilla(
        conv_name + "_s",
        f_in,
        kernel_shape=ksize,
        group='chan',
        output_nr_channel=f_in.partial_shape[1],
        stride=stride,
        padding=pad,
    )
    out = Conv2D(
        conv_name,
        out,
        kernel_shape=1,
        stride=1,
        padding=0,
        output_nr_channel=num_outputs,
        nonlinearity=mgsk.opr.helper.elemwise_trans.Identity(),
    )

    if has_bn:
        bn_name = bn_name_fun(prefix) if bn_name_fun else "bn_" + prefix
        out = BatchNormalization(bn_name, out, eps=1e-9)
        out = ElementwiseAffine(
            bn_name + "_scaleshift", out, shared_in_channels=False)
        # Result is intentionally discarded; the call is kept for whatever
        # side effect it has on the affine operator (not visible here).
        out.get_param_shape("k")

    if has_relu:
        out = ReLU(out)

    return out
def conv_wn(inp, ker_shape, stride, padding, out_chl, isrelu):
    """Build a convolution whose output is divided by per-filter weight norms.

    A ``Conv2D`` (weight initialiser ``G(mean=0, std=0.05)``) is created,
    then its output is divided by the L2 norm of its weight taken over
    axes 1, 2 and 3, broadcast along the channel axis.  An
    ``ElementwiseAffine`` and an optional ReLU follow; no BN is used.
    The module-level ``idx`` counter is incremented once per call.

    Args:
        inp: Input variable.
        ker_shape, stride, padding, out_chl: Forwarded to ``Conv2D``.
        isrelu: Apply ``arith.ReLU`` after the affine step.

    Returns:
        ``(activated, normalised, weight)``: the final output, the
        norm-divided convolution output, and the convolution's weight
        variable (``conv.inputs[1]``).

    Raises:
        AssertionError: if the name of ``conv.inputs[1]`` does not contain
            ``":W"``.
    """
    global idx
    idx += 1

    conv_out = Conv2D(
        "conv{}".format(idx),
        inp,
        kernel_shape=ker_shape,
        stride=stride,
        padding=padding,
        output_nr_channel=out_chl,
        W=G(mean=0, std=0.05),
        nonlinearity=Identity(),
    )

    weight = conv_out.inputs[1]
    assert ":W" in weight.name

    # One norm per leading-axis entry of the weight.
    squared = weight**2
    filter_norm = squared.sum(axis=3).sum(axis=2).sum(axis=1)**0.5
    normalised = conv_out / filter_norm.dimshuffle('x', 0, 'x', 'x')

    activated = ElementwiseAffine(
        "bnaff{}".format(idx),
        normalised,
        shared_in_channels=False,
        k=C(1),
        b=C(0),
    )
    if isrelu:
        activated = arith.ReLU(activated)

    return activated, normalised, weight
def conv_norm(inp, ker_shape, stride, padding, out_chl, isrelu):
    """Build Conv2D -> normalisation over axes 2 and 3 -> affine -> (ReLU).

    After the convolution, the mean and standard deviation are computed
    over axes 3 and 2 for every (axis 0, axis 1) entry, and the
    output is centred and divided by that standard deviation.  The
    module-level ``idx`` counter is incremented once per call.

    NOTE(review): the division has no epsilon, so a map that is constant
    over axes 2 and 3 (std == 0) would divide by zero -- confirm callers
    never hit that case.

    Args:
        inp: Input variable.
        ker_shape, stride, padding, out_chl: Forwarded to ``Conv2D``.
        isrelu: Apply ``O.ReLU`` after the affine step.

    Returns:
        The output of the last operator that was applied.
    """
    global idx
    idx += 1

    def _broadcast(stat):
        # Re-expand a two-axis statistic so it lines up with the 4-axis map.
        return stat.dimshuffle(0, 1, 'x', 'x')

    feat = Conv2D(
        "conv{}".format(idx),
        inp,
        kernel_shape=ker_shape,
        stride=stride,
        padding=padding,
        output_nr_channel=out_chl,
        nonlinearity=Identity(),
    )

    mean = feat.mean(axis=3).mean(axis=2)
    std = ((feat - _broadcast(mean))**2).mean(axis=3).mean(axis=2)**0.5
    feat = (feat - _broadcast(mean)) / _broadcast(std)

    feat = ElementwiseAffine(
        "aff{}".format(idx),
        feat,
        shared_in_channels=False,
        k=C(1),
        b=C(0),
    )
    return O.ReLU(feat) if isrelu else feat
def conv_bn(inp, ker_shape, stride, padding, out_chl, isrelu):
    """Build two half-width convolutions (second one frozen) -> BN -> affine.

    Two ``Conv2D`` operators with ``out_chl // 2`` output channels each are
    applied to the same input, using the weight initialiser
    ``G(mean=0, std=sqrt((1 + int(isrelu)) / (ker_shape**2 * in_channels)))``.
    ``set_freezed()`` is called on the owner operators of the second
    convolution's inputs 1 and 2 (bound to ``W`` and ``b`` here).  The two
    outputs are concatenated on axis 1 and followed by ``BN``,
    ``ElementwiseAffine`` and an optional ReLU.  The module-level ``idx``
    counter is incremented once per call.

    Args:
        inp: Input variable; ``inp.partial_shape[1]`` is read as the input
            channel count.
        ker_shape, stride, padding, out_chl: Forwarded to both ``Conv2D``
            operators (each gets ``out_chl // 2`` output channels).
        isrelu: Apply ``arith.ReLU`` at the end; also scales the init std.

    Returns:
        ``(activated, concatenated)``: the final output and the raw
        concatenated convolution output.
    """
    global idx
    idx += 1

    half_chl = out_chl // 2
    init_std = ((1 + int(isrelu)) / (ker_shape**2 * inp.partial_shape[1]))**0.5

    trainable = Conv2D(
        "conv{}_0".format(idx),
        inp,
        kernel_shape=ker_shape,
        stride=stride,
        padding=padding,
        output_nr_channel=half_chl,
        W=G(mean=0, std=init_std),
        nonlinearity=Identity(),
    )
    frozen = Conv2D(
        "conv{}_1".format(idx),
        inp,
        kernel_shape=ker_shape,
        stride=stride,
        padding=padding,
        output_nr_channel=half_chl,
        W=G(mean=0, std=init_std),
        nonlinearity=Identity(),
    )

    # Freeze both parameters of the second convolution.
    frozen_w = frozen.inputs[1].owner_opr
    frozen_b = frozen.inputs[2].owner_opr
    frozen_w.set_freezed()
    frozen_b.set_freezed()

    concatenated = Concat([trainable, frozen], axis=1)

    activated = BN("bn{}".format(idx), concatenated, eps=1e-9)
    activated = ElementwiseAffine(
        "bnaff{}".format(idx),
        activated,
        shared_in_channels=False,
        k=C(1),
        b=C(0),
    )
    if isrelu:
        activated = arith.ReLU(activated)

    return activated, concatenated