def complexity(cx, w_in, w_se):
    """Accumulate the cost of the w_in -> w_se -> w_in pair of 1x1 convs.

    Both convolutions carry a bias and are counted on a 1x1 spatial map
    (presumably the globally pooled squeeze path of an SE block). The
    caller's spatial size is put back before returning.

    Args:
        cx: complexity accumulator dict; must contain "h" and "w".
        w_in: channel count entering and leaving the block.
        w_se: channel count of the bottleneck between the two convs.

    Returns:
        The updated accumulator.
    """
    saved_hw = (cx["h"], cx["w"])
    cx["h"] = cx["w"] = 1
    for c_in, c_out in ((w_in, w_se), (w_se, w_in)):
        cx = net.complexity_conv2d(cx, c_in, c_out, 1, 1, 0, bias=True)
    cx["h"], cx["w"] = saved_hw
    return cx
def complexity(cx, w_in, w_out, stride, w_b=None, num_gs=1):
    """Accumulate the cost of a basic transform: two 3x3 conv + BN layers.

    The first conv applies ``stride``; the second always uses stride 1.
    ``w_b`` and ``num_gs`` exist only for signature compatibility with the
    bottleneck transform and must be left at their defaults.

    Returns:
        The updated accumulator.
    """
    assert w_b is None and num_gs == 1, "Basic transform does not support w_b and num_gs options"
    # (in_channels, out_channels, kernel, stride, padding); BN follows each conv.
    for c_in, c_out, k, s, p in ((w_in, w_out, 3, stride, 1), (w_out, w_out, 3, 1, 1)):
        cx = net.complexity_conv2d(cx, c_in, c_out, k, s, p)
        cx = net.complexity_batchnorm2d(cx, c_out)
    return cx
def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
    """Accumulate the cost of a vanilla block: two 3x3 conv + BN layers.

    Only the first conv is strided. ``bm``, ``gw`` and ``se_r`` are accepted
    for interface compatibility and must all be None.

    Returns:
        The updated accumulator.
    """
    assert bm is None and gw is None and se_r is None, "Vanilla block does not support bm, gw, and se_r options"
    layer_specs = (
        (w_in, w_out, stride),  # conv a: changes width, applies stride
        (w_out, w_out, 1),      # conv b: keeps width and resolution
    )
    for c_in, c_out, s in layer_specs:
        cx = net.complexity_conv2d(cx, c_in, c_out, 3, s, 1)
        cx = net.complexity_batchnorm2d(cx, c_out)
    return cx
def complexity(cx, w_in, w_out, stride, w_b, num_gs):
    """Accumulate the cost of a 1x1 -> 3x3 -> 1x1 bottleneck transform.

    ``cfg.RESNET.STRIDE_1X1`` chooses whether ``stride`` is applied by the
    first 1x1 conv or by the 3x3 conv; the other one runs at stride 1. Each
    conv is followed by a batchnorm.

    Returns:
        The updated accumulator.
    """
    if cfg.RESNET.STRIDE_1X1:
        stride_a, stride_b = stride, 1
    else:
        stride_a, stride_b = 1, stride
    # 1x1: w_in -> w_b
    cx = net.complexity_conv2d(cx, w_in, w_b, 1, stride_a, 0)
    cx = net.complexity_batchnorm2d(cx, w_b)
    # 3x3: w_b -> w_b, split into num_gs groups
    cx = net.complexity_conv2d(cx, w_b, w_b, 3, stride_b, 1, num_gs)
    cx = net.complexity_batchnorm2d(cx, w_b)
    # 1x1: w_b -> w_out
    cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
    return net.complexity_batchnorm2d(cx, w_out)
def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out):
    """Accumulate the cost of an MBConv-style block.

    Sequence: optional 1x1 expansion, depthwise kxk conv, SE, then a 1x1
    projection to ``w_out``. Each conv is followed by a batchnorm.

    Args:
        cx: complexity accumulator dict.
        w_in: input channels.
        exp_r: expansion ratio; the expanded width is int(w_in * exp_r).
        kernel: depthwise kernel size.
        stride: depthwise stride.
        se_r: SE ratio, applied to ``w_in`` (not the expanded width).
        w_out: output channels.

    Returns:
        The updated accumulator.
    """
    expanded = int(w_in * exp_r)
    # The expansion conv is only present when it actually changes the width.
    if expanded != w_in:
        cx = net.complexity_conv2d(cx, w_in, expanded, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, expanded)
    # groups == channels -> depthwise; (k - 1) // 2 keeps size for odd k at stride 1
    cx = net.complexity_conv2d(cx, expanded, expanded, kernel, stride, (kernel - 1) // 2, expanded)
    cx = net.complexity_batchnorm2d(cx, expanded)
    cx = SE.complexity(cx, expanded, int(w_in * se_r))
    cx = net.complexity_conv2d(cx, expanded, w_out, 1, 1, 0)
    return net.complexity_batchnorm2d(cx, w_out)
def complexity(cx, w_in, w_out, stride, bm, gw, se_r):
    """Accumulate the cost of a grouped bottleneck transform with optional SE.

    Args:
        cx: complexity accumulator dict.
        w_in: input channels.
        w_out: output channels.
        stride: stride of the grouped 3x3 conv.
        bm: bottleneck multiplier; bottleneck width is round(w_out * bm).
        gw: channels per group in the 3x3 conv.
        se_r: SE ratio relative to ``w_in``; a falsy value skips SE.

    Returns:
        The updated accumulator.
    """
    width = int(round(w_out * bm))
    n_groups = width // gw
    # 1x1 reduce
    cx = net.complexity_conv2d(cx, w_in, width, 1, 1, 0)
    cx = net.complexity_batchnorm2d(cx, width)
    # grouped 3x3 (carries the stride)
    cx = net.complexity_conv2d(cx, width, width, 3, stride, 1, n_groups)
    cx = net.complexity_batchnorm2d(cx, width)
    if se_r:
        cx = SE.complexity(cx, width, int(round(w_in * se_r)))
    # 1x1 expand back to w_out
    cx = net.complexity_conv2d(cx, width, w_out, 1, 1, 0)
    return net.complexity_batchnorm2d(cx, w_out)
def complexity(cx, w_in, w_out, w_mid, ksize, stride, proj):
    """Accumulate the cost of a WeightNet ShuffleNetV2-style block.

    A bias-carrying 1x1 "reduce" conv (w_in -> max(16, w_in // 16)) is
    counted on a 1x1 spatial map. Then, when ``proj`` is set, a projection
    branch (strided depthwise WeightNet + 1x1 WeightNet) is counted. Finally
    the main branch (1x1 -> depthwise -> 1x1 WeightNets) is counted.

    The projection branch and the main branch both consume the block input.
    The projection branch's strided depthwise conv may shrink cx["h"]/cx["w"],
    so the input resolution is restored before the main branch. Otherwise the
    main branch would be counted at the reduced size and strided a second
    time, producing a wrong output resolution.

    Args:
        cx: complexity accumulator dict; must contain "h" and "w".
        w_in: input channels of each branch.
        w_out: output channels of the main branch.
        w_mid: hidden width of the main branch.
        ksize: depthwise kernel size.
        stride: depthwise stride.
        proj: whether the projection branch is present.

    Returns:
        The updated accumulator.
    """
    in_h, in_w = cx["h"], cx["w"]
    # Reduce conv is evaluated on a 1x1 map; restore the spatial size afterwards.
    cx["h"], cx["w"] = 1, 1
    cx = net.complexity_conv2d(cx, w_in, max(16, w_in // 16), 1, 1, 0, bias=True)
    cx["h"], cx["w"] = in_h, in_w
    if proj:
        # Projection branch.
        cx = WeightNet_DW.complexity(cx, w_in, ksize, stride)
        cx = net.complexity_batchnorm2d(cx, w_in)
        cx = WeightNet.complexity(cx, w_in, w_in, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_in)
        # The main branch runs in parallel on the same input: rewind resolution.
        cx["h"], cx["w"] = in_h, in_w
    # TODO: add the complexity of channel_shuffle (non-projection case)
    # Main branch.
    cx = WeightNet.complexity(cx, w_in, w_mid, 1, 1)
    cx = net.complexity_batchnorm2d(cx, w_mid)
    cx = WeightNet_DW.complexity(cx, w_mid, ksize, stride)
    cx = net.complexity_batchnorm2d(cx, w_mid)
    cx = WeightNet.complexity(cx, w_mid, w_out, 1, 1)
    cx = net.complexity_batchnorm2d(cx, w_out)
    return cx
def complexity(cx, w_in, head_channels, nc):
    """Accumulate the cost of a multi-layer classification head.

    A 1x1 conv + BN maps ``w_in`` to ``head_channels[0]`` at the current
    resolution. Everything after that is counted on a 1x1 map: a chain of
    1x1 convs through ``head_channels`` followed by a bias-carrying 1x1 conv
    to ``nc`` outputs.

    Raises:
        IndexError: if ``head_channels`` is empty.

    Returns:
        The updated accumulator.
    """
    first = head_channels[0]
    cx = net.complexity_conv2d(cx, w_in, first, 1, 1, 0)
    cx = net.complexity_batchnorm2d(cx, first)
    cx["h"], cx["w"] = 1, 1
    for c_in, c_out in zip(head_channels, head_channels[1:]):
        cx = net.complexity_conv2d(cx, c_in, c_out, 1, 1, 0)
        # NOTE(review): no batchnorm is counted for these hidden convs (the
        # call was commented out upstream) -- confirm against the forward pass.
    cx = net.complexity_conv2d(cx, head_channels[-1], nc, 1, 1, 0, bias=True)
    return cx
def complexity(cx, w_in, w_out, stride, trans_fun, w_b, num_gs):
    """Accumulate the cost of a residual block built around ``trans_fun``.

    When the width or resolution changes, a 1x1 conv + BN shortcut is also
    counted. Because it runs in parallel with the transform, the spatial size
    is rewound before the transform is counted.

    Returns:
        The updated accumulator.
    """
    if w_in != w_out or stride != 1:
        saved_hw = (cx["h"], cx["w"])
        cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx["h"], cx["w"] = saved_hw
    return trans_fun.complexity(cx, w_in, w_out, stride, w_b, num_gs)
def complexity(cx, w_in, w_out, stride, bm=1.0, gw=1, se_r=None):
    """Accumulate the cost of a residual block using ``BottleneckTransform``.

    A 1x1 conv + BN projection shortcut is counted when the width or
    resolution changes. The spatial size is then rewound so that the
    parallel bottleneck branch starts from the block input resolution.

    Returns:
        The updated accumulator.
    """
    needs_projection = w_in != w_out or stride != 1
    if needs_projection:
        saved_hw = (cx["h"], cx["w"])
        cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx["h"], cx["w"] = saved_hw
    return BottleneckTransform.complexity(cx, w_in, w_out, stride, bm, gw, se_r)
def complexity(cx, w_in, w_out, stride, bm=None, gw=None, se_r=None):
    """Accumulate the cost of a residual block using ``BasicTransform``.

    ``bm``, ``gw`` and ``se_r`` are accepted only for interface parity and
    must all be None. A 1x1 conv + BN projection shortcut is counted when
    the width or resolution changes, and the spatial size is rewound before
    the parallel transform branch is counted.

    Returns:
        The updated accumulator.
    """
    assert bm is None and gw is None and se_r is None, "Basic transform does not support bm, gw, and se_r options"
    if w_in != w_out or stride != 1:
        saved_hw = (cx["h"], cx["w"])
        cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx["h"], cx["w"] = saved_hw
    return BasicTransform.complexity(cx, w_in, w_out, stride)
def complexity(cx, w_in, w_out, w_mid, ksize, stride, proj):
    """Accumulate the cost of a ShuffleNetV2-style block.

    When ``proj`` is set, a projection branch is counted first: a strided
    depthwise kxk conv + BN, then a 1x1 conv + BN, all at width ``w_in``.
    The main branch is counted next: a 1x1 conv to ``w_mid``, a strided
    depthwise kxk conv, and a 1x1 conv to ``w_out``, each followed by BN.
    All convolutions are bias-free.

    Both branches consume the block input. The spatial size is therefore
    restored after the projection branch, so the main branch's first 1x1
    conv is counted at the input resolution rather than the already-strided
    one. The main branch then applies ``stride`` itself, so the final
    cx["h"]/cx["w"] match the block's output resolution.

    Args:
        cx: complexity accumulator dict; must contain "h" and "w".
        w_in: input channels of each branch.
        w_out: output channels of the main branch.
        w_mid: hidden width of the main branch.
        ksize: depthwise kernel size (padding ksize // 2).
        stride: depthwise stride.
        proj: whether the projection branch is present.

    Returns:
        The updated accumulator.
    """
    if proj:
        in_h, in_w = cx["h"], cx["w"]
        # Projection branch.
        cx = net.complexity_conv2d(cx, w_in, w_in, ksize, stride, ksize // 2, groups=w_in, bias=False)
        cx = net.complexity_batchnorm2d(cx, w_in)
        cx = net.complexity_conv2d(cx, w_in, w_in, 1, 1, 0, bias=False)
        cx = net.complexity_batchnorm2d(cx, w_in)
        # The main branch runs in parallel on the same input: rewind resolution.
        cx["h"], cx["w"] = in_h, in_w
    # TODO: add the complexity of channel_shuffle (non-projection case)
    # Main branch.
    cx = net.complexity_conv2d(cx, w_in, w_mid, 1, 1, 0, bias=False)
    cx = net.complexity_batchnorm2d(cx, w_mid)
    cx = net.complexity_conv2d(cx, w_mid, w_mid, ksize, stride, ksize // 2, groups=w_mid, bias=False)
    cx = net.complexity_batchnorm2d(cx, w_mid)
    cx = net.complexity_conv2d(cx, w_mid, w_out, 1, 1, 0, bias=False)
    cx = net.complexity_batchnorm2d(cx, w_out)
    return cx
def complexity(cx, w_in, w_out):
    """Accumulate the cost of a stem: 3x3/stride-2 conv + BN, then a max-pool.

    The max-pool uses kernel 3, stride 2 and padding 1.

    Returns:
        The updated accumulator.
    """
    conv_cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
    bn_cx = net.complexity_batchnorm2d(conv_cx, w_out)
    return net.complexity_maxpool2d(bn_cx, 3, 2, 1)
def complexity(cx, w_in, w_out, nc):
    """Accumulate the cost of a head: 1x1 conv + BN, then a pooled classifier.

    The final bias-carrying 1x1 conv (``w_out`` -> ``nc``) is counted on a
    1x1 spatial map.

    Returns:
        The updated accumulator.
    """
    cx = net.complexity_batchnorm2d(net.complexity_conv2d(cx, w_in, w_out, 1, 1, 0), w_out)
    cx["h"] = 1
    cx["w"] = 1
    return net.complexity_conv2d(cx, w_out, nc, 1, 1, 0, bias=True)
def complexity(cx, w_in, nc):
    """Accumulate the cost of a pooled classifier head.

    Sets the spatial size to 1x1, then counts one bias-carrying 1x1 conv
    from ``w_in`` to ``nc`` channels.

    Returns:
        The updated accumulator.
    """
    cx["h"] = cx["w"] = 1
    return net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
def complexity(cx, w_in, w_out):
    """Accumulate the cost of a bias-free 1x1 conv (w_in -> w_out) plus BN.

    Returns:
        The updated accumulator.
    """
    conv_cx = net.complexity_conv2d(cx, w_in, w_out, 1, 1, 0, bias=False)
    return net.complexity_batchnorm2d(conv_cx, w_out)
def complexity(cx, w_in, w_out):
    """Accumulate the cost of a 3x3 conv (stride 1, padding 1) plus BN.

    Returns:
        The updated accumulator.
    """
    conv_cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
    return net.complexity_batchnorm2d(conv_cx, w_out)