def complexity(cx, w_in, w_out, stride, params):
    """Accumulate the cost of a bottleneck transform with optional SE-style modules.

    Layer sequence: 1x1 conv -> norm -> 3x3 grouped conv (strided) -> norm ->
    any enabled squeeze/excitation-style modules (in a fixed order) ->
    1x1 conv -> norm.

    Args:
        cx: running complexity accumulator threaded through the ``*_cx``
            helpers and ``*.complexity`` calls.
        w_in: input channel width.
        w_out: output channel width.
        stride: stride applied by the 3x3 conv.
        params: dict of block options. Keys read here: ``bot_mul``, ``se_r``,
            ``se1_r``, ``group_w``, ``c_se``, ``w_se``, ``ew_se``, ``w1_se``,
            ``w13_se``, ``se_gap``, ``se_gap1``, ``se_gap_dw``; plus
            ``block_idx`` and ``w_se_idx`` only when ``ew_se`` is truthy.

    Returns:
        The updated accumulator.
    """
    bottleneck_w = int(round(w_out * params["bot_mul"]))
    se_w = int(round(w_in * params["se_r"]))
    se1_w = int(round(w_in * params["se1_r"]))
    num_groups = bottleneck_w // params["group_w"]

    # 1x1 reduction to the bottleneck width, then the strided 3x3 grouped conv.
    cx = norm2d_cx(conv2d_cx(cx, w_in, bottleneck_w, 1), bottleneck_w)
    cx = conv2d_cx(cx, bottleneck_w, bottleneck_w, 3, stride=stride, groups=num_groups)
    cx = norm2d_cx(cx, bottleneck_w)

    # Optional attention modules, applied in this exact order.
    if se_w:
        cx = SE.complexity(cx, bottleneck_w, se_w)
    if params['c_se']:
        cx = C_SE.complexity(cx, bottleneck_w, se_w)
    if params['w_se']:
        cx = W_SE.complexity(cx, bottleneck_w, se1_w)
    if params['ew_se']:
        # Blocks whose index is listed in w_se_idx get a plain SE module
        # (sized with se1_r) instead of EW_SE.
        se_cls = SE if params['block_idx'] in params['w_se_idx'] else EW_SE
        cx = se_cls.complexity(cx, bottleneck_w, se1_w)
    if params['w1_se']:
        cx = W1_SE.complexity(cx, bottleneck_w, se1_w)
    if params['w13_se']:
        cx = W13_SE.complexity(cx, bottleneck_w, se1_w)
    if params['se_gap']:
        cx = SE_GAP.complexity(cx, bottleneck_w, se1_w)
    if params['se_gap1']:
        cx = SE_GAP1.complexity(cx, bottleneck_w, se1_w)
    if params['se_gap_dw']:
        cx = SE_GAP_DW.complexity(cx, bottleneck_w)

    # 1x1 expansion back to the output width.
    cx = conv2d_cx(cx, bottleneck_w, w_out, 1)
    return norm2d_cx(cx, w_out)
def complexity(cx, w_in, w_out, stride, w_b=None, groups=1):
    """Accumulate the cost of a basic transform: two 3x3 conv + norm pairs.

    Args:
        cx: running complexity accumulator threaded through the ``*_cx`` helpers.
        w_in: input channel width.
        w_out: output channel width (used by both convs).
        stride: stride of the first 3x3 conv.
        w_b: must be None; accepted only for signature compatibility.
        groups: must be 1; accepted only for signature compatibility.

    Returns:
        The updated accumulator.

    Raises:
        AssertionError: if ``w_b`` is not None or ``groups`` != 1 (only when
            assertions are enabled).
    """
    assert w_b is None and groups == 1, "Basic transform does not support w_b and groups options"
    # The first conv changes the width and carries the stride; the second keeps the width.
    cx = norm2d_cx(conv2d_cx(cx, w_in, w_out, 3, stride=stride), w_out)
    cx = norm2d_cx(conv2d_cx(cx, w_out, w_out, 3), w_out)
    return cx
def complexity(cx, w_in, w_out, stride, w_b, groups):
    """Accumulate the cost of a 1x1 -> 3x3 (grouped) -> 1x1 bottleneck transform.

    Args:
        cx: running complexity accumulator threaded through the ``*_cx`` helpers.
        w_in: input channel width.
        w_out: output channel width.
        stride: stride applied to either the first 1x1 or the 3x3 conv.
        w_b: bottleneck width.
        groups: number of groups for the 3x3 conv.

    Returns:
        The updated accumulator.
    """
    # cfg.RESNET.STRIDE_1X1 decides which conv carries the stride.
    if cfg.RESNET.STRIDE_1X1:
        stride_1x1, stride_3x3 = stride, 1
    else:
        stride_1x1, stride_3x3 = 1, stride
    cx = norm2d_cx(conv2d_cx(cx, w_in, w_b, 1, stride=stride_1x1), w_b)
    cx = norm2d_cx(conv2d_cx(cx, w_b, w_b, 3, stride=stride_3x3, groups=groups), w_b)
    cx = norm2d_cx(conv2d_cx(cx, w_b, w_out, 1), w_out)
    return cx
def complexity(cx, w_in, exp_r, k, stride, se_r, w_out):
    """Accumulate the cost of an expand -> depthwise -> SE -> project block.

    Args:
        cx: running complexity accumulator threaded through the ``*_cx``
            helpers and ``SE.complexity``.
        w_in: input channel width.
        exp_r: expansion ratio; expanded width is ``int(w_in * exp_r)``.
        k: kernel size of the depthwise conv.
        stride: stride of the depthwise conv.
        se_r: SE ratio; SE width is ``int(w_in * se_r)``.
        w_out: output channel width.

    Returns:
        The updated accumulator.
    """
    exp_w = int(w_in * exp_r)
    # The 1x1 expansion is counted only when it actually changes the width.
    if exp_w != w_in:
        cx = norm2d_cx(conv2d_cx(cx, w_in, exp_w, 1), exp_w)
    # Depthwise kxk conv: groups equals the channel count.
    cx = norm2d_cx(conv2d_cx(cx, exp_w, exp_w, k, stride=stride, groups=exp_w), exp_w)
    cx = SE.complexity(cx, exp_w, int(w_in * se_r))
    # 1x1 projection to the output width.
    cx = norm2d_cx(conv2d_cx(cx, exp_w, w_out, 1), w_out)
    return cx
def complexity(cx, w_in, w_out, stride, params):
    """Accumulate the cost of a bottleneck transform with an optional SE module.

    Args:
        cx: running complexity accumulator threaded through the ``*_cx``
            helpers and ``SE.complexity``.
        w_in: input channel width.
        w_out: output channel width.
        stride: stride of the 3x3 grouped conv.
        params: dict with keys ``bot_mul``, ``se_r`` and ``group_w``.

    Returns:
        The updated accumulator.
    """
    bottleneck_w = int(round(w_out * params["bot_mul"]))
    se_w = int(round(w_in * params["se_r"]))
    num_groups = bottleneck_w // params["group_w"]

    cx = norm2d_cx(conv2d_cx(cx, w_in, bottleneck_w, 1), bottleneck_w)
    cx = conv2d_cx(cx, bottleneck_w, bottleneck_w, 3, stride=stride, groups=num_groups)
    cx = norm2d_cx(cx, bottleneck_w)
    # SE is skipped entirely when its rounded width is zero.
    if se_w:
        cx = SE.complexity(cx, bottleneck_w, se_w)
    cx = conv2d_cx(cx, bottleneck_w, w_out, 1)
    return norm2d_cx(cx, w_out)
def complexity(cx, w_in, w_out, stride, params):
    """Accumulate the cost of a residual bottleneck block.

    A 1x1 projection shortcut (conv + norm) is counted when the width or
    the stride changes. The ``BottleneckTransform`` branch is always counted.

    Args:
        cx: running complexity accumulator; this code reads and writes its
            ``"h"`` and ``"w"`` entries.
        w_in: input channel width.
        w_out: output channel width.
        stride: block stride.
        params: options forwarded unchanged to ``BottleneckTransform.complexity``.

    Returns:
        The updated accumulator.
    """
    needs_proj = (w_in != w_out) or (stride != 1)
    if needs_proj:
        saved_hw = cx["h"], cx["w"]
        cx = norm2d_cx(conv2d_cx(cx, w_in, w_out, 1, stride=stride), w_out)
        # Reset h/w to their pre-projection values so the transform branch is
        # accounted from the block's input resolution (the projection helper
        # presumably updates them; confirm in conv2d_cx).
        cx["h"], cx["w"] = saved_hw
    return BottleneckTransform.complexity(cx, w_in, w_out, stride, params)
def complexity(cx, w_in, w_out, stride, trans_fun, w_b, groups):
    """Accumulate the cost of a residual block built around ``trans_fun``.

    A 1x1 projection shortcut (conv + norm) is counted when the width or
    the stride changes. The transform branch is always counted.

    Args:
        cx: running complexity accumulator; this code reads and writes its
            ``"h"`` and ``"w"`` entries.
        w_in: input channel width.
        w_out: output channel width.
        stride: block stride.
        trans_fun: transform class exposing a ``complexity`` callable.
        w_b: bottleneck width forwarded to ``trans_fun.complexity``.
        groups: group count forwarded to ``trans_fun.complexity``.

    Returns:
        The updated accumulator.
    """
    needs_proj = (w_in != w_out) or (stride != 1)
    if needs_proj:
        saved_hw = cx["h"], cx["w"]
        cx = norm2d_cx(conv2d_cx(cx, w_in, w_out, 1, stride=stride), w_out)
        # Reset h/w to their pre-projection values so the transform branch is
        # accounted from the block's input resolution.
        cx["h"], cx["w"] = saved_hw
    return trans_fun.complexity(cx, w_in, w_out, stride, w_b, groups)
def complexity(cx, w_in, head_width, num_classes):
    """Accumulate the cost of a classification head.

    When ``head_width > 0``, a 1x1 conv + norm to ``head_width`` is counted
    first. Then the pooling step and a biased linear layer to
    ``num_classes`` are counted.

    Args:
        cx: running complexity accumulator threaded through the ``*_cx`` helpers.
        w_in: input channel width.
        head_width: width of the optional 1x1 conv; values <= 0 disable it.
        num_classes: output size of the linear layer.

    Returns:
        The updated accumulator.
    """
    feat_w = w_in
    if head_width > 0:
        cx = norm2d_cx(conv2d_cx(cx, w_in, head_width, 1), head_width)
        feat_w = head_width
    # gap2d_cx is presumably global average pooling (name-based; confirm).
    cx = gap2d_cx(cx, feat_w)
    return linear_cx(cx, feat_w, num_classes, bias=True)
def complexity(cx, w_in, ks, ws, ss):
    """Accumulate the cost of a stack of stem convs.

    Layers are taken from ``zip(ks, ws, ss)``. Every layer whose index is
    below ``len(ks) - 1`` is a 3x3 conv followed by norm; its ``k`` entry
    is ignored. The layer at index ``len(ks) - 1`` is a ``k``x``k`` conv
    with bias and no norm.

    Args:
        cx: running complexity accumulator threaded through the ``*_cx`` helpers.
        w_in: input channel width of the first conv.
        ks: kernel sizes (only the final one is used as a kernel size).
        ws: output widths, one per layer.
        ss: strides, one per layer.

    Returns:
        The updated accumulator.
    """
    cur_w = w_in
    for idx, (k, w_out, stride) in enumerate(zip(ks, ws, ss)):
        if idx >= len(ks) - 1:
            # Final layer: kernel k, biased, no norm.
            cx = conv2d_cx(cx, cur_w, w_out, k, stride=stride, bias=True)
        else:
            cx = conv2d_cx(cx, cur_w, w_out, 3, stride=stride)
            cx = norm2d_cx(cx, w_out)
        cur_w = w_out
    return cx
def complexity(cx, w_in, w_out, num_classes):
    """Accumulate the cost of a 1x1 conv + norm -> pool -> linear head.

    Args:
        cx: running complexity accumulator threaded through the ``*_cx`` helpers.
        w_in: input channel width.
        w_out: width after the 1x1 conv, also the linear layer's input size.
        num_classes: output size of the biased linear layer.

    Returns:
        The updated accumulator.
    """
    cx = norm2d_cx(conv2d_cx(cx, w_in, w_out, 1), w_out)
    cx = gap2d_cx(cx, w_out)
    return linear_cx(cx, w_out, num_classes, bias=True)
def complexity(cx, w_in, w_out):
    """Accumulate the cost of a single 3x3, stride-2 conv followed by norm.

    Args:
        cx: running complexity accumulator threaded through the ``*_cx`` helpers.
        w_in: input channel width.
        w_out: output channel width.

    Returns:
        The updated accumulator.
    """
    return norm2d_cx(conv2d_cx(cx, w_in, w_out, 3, stride=2), w_out)
def complexity(cx, w_in, w_out, stride, _params):
    """Accumulate the cost of a basic transform: two 3x3 conv + norm pairs.

    Args:
        cx: running complexity accumulator threaded through the ``*_cx`` helpers.
        w_in: input channel width.
        w_out: output channel width (used by both convs).
        stride: stride of the first 3x3 conv.
        _params: unused; accepted for signature parity with other transforms.

    Returns:
        The updated accumulator.
    """
    # The first conv changes the width and carries the stride; the second keeps the width.
    cx = norm2d_cx(conv2d_cx(cx, w_in, w_out, 3, stride=stride), w_out)
    cx = norm2d_cx(conv2d_cx(cx, w_out, w_out, 3), w_out)
    return cx