コード例 #1
0
 def complexity(cx, w_in, w_out, stride, params):
     """Accumulate the cost of a bottleneck transform with optional SE variants.

     Layer order: 1x1 conv (w_in -> bottleneck width) -> norm -> 3x3 grouped
     conv carrying `stride` -> norm -> zero or more SE-style modules selected
     by flags in `params` -> 1x1 conv (bottleneck width -> w_out) -> norm.
     `cx` is threaded through each helper and the final value is returned.

     Args:
         cx: complexity accumulator passed through the per-layer helpers.
         w_in: input channel width.
         w_out: output channel width.
         stride: stride applied by the 3x3 conv.
         params: dict read for "bot_mul", "se_r", "se1_r", "group_w" and the
             flags "c_se", "w_se", "ew_se", "w1_se", "w13_se", "se_gap",
             "se_gap1", "se_gap_dw" (plus "block_idx" and "w_se_idx" when
             "ew_se" is set).
     """
     bot_w = int(round(w_out * params["bot_mul"]))
     se_w = int(round(w_in * params["se_r"]))
     se1_w = int(round(w_in * params["se1_r"]))
     n_groups = bot_w // params["group_w"]

     # 1x1 conv into the bottleneck width.
     cx = conv2d_cx(cx, w_in, bot_w, 1)
     cx = norm2d_cx(cx, bot_w)
     # 3x3 grouped conv: the only layer here that receives the stride.
     cx = conv2d_cx(cx, bot_w, bot_w, 3, stride=stride, groups=n_groups)
     cx = norm2d_cx(cx, bot_w)

     # Optional SE-style modules. Each flag is checked independently and in
     # this fixed order, so several modules may stack on the same block.
     if se_w:
         cx = SE.complexity(cx, bot_w, se_w)
     if params['c_se']:
         # NOTE(review): C_SE is sized with se_r while the variants below
         # use se1_r -- confirm this is intentional.
         cx = C_SE.complexity(cx, bot_w, se_w)
     if params['w_se']:
         cx = W_SE.complexity(cx, bot_w, se1_w)
     if params['ew_se']:
         # Blocks whose index is listed in w_se_idx use a plain SE instead.
         in_plain_list = params['block_idx'] in params['w_se_idx']
         se_module = SE if in_plain_list else EW_SE
         cx = se_module.complexity(cx, bot_w, se1_w)
     if params['w1_se']:
         cx = W1_SE.complexity(cx, bot_w, se1_w)
     if params['w13_se']:
         cx = W13_SE.complexity(cx, bot_w, se1_w)
     if params['se_gap']:
         cx = SE_GAP.complexity(cx, bot_w, se1_w)
     if params['se_gap1']:
         cx = SE_GAP1.complexity(cx, bot_w, se1_w)
     if params['se_gap_dw']:
         cx = SE_GAP_DW.complexity(cx, bot_w)

     # 1x1 conv from the bottleneck width to the output width.
     cx = conv2d_cx(cx, bot_w, w_out, 1)
     return norm2d_cx(cx, w_out)
コード例 #2
0
ファイル: resnet.py プロジェクト: llwx593/Efficientnet
 def complexity(cx, w_in, w_out, stride, w_b=None, groups=1):
     """Accumulate the cost of a basic transform: two 3x3 conv + norm pairs.

     The first conv maps w_in -> w_out and carries `stride`. The second conv
     keeps w_out channels. `w_b` and `groups` are accepted so the signature
     matches callers that pass them to any transform, but they must be left
     at their defaults here.

     Raises:
         AssertionError: if `w_b` or `groups` is set (check skipped under -O).
     """
     unsupported_msg = "Basic transform does not support w_b and groups options"
     assert w_b is None and groups == 1, unsupported_msg
     # First 3x3: changes width and applies the stride.
     cx = norm2d_cx(conv2d_cx(cx, w_in, w_out, 3, stride=stride), w_out)
     # Second 3x3: width-preserving.
     cx = norm2d_cx(conv2d_cx(cx, w_out, w_out, 3), w_out)
     return cx
コード例 #3
0
ファイル: resnet.py プロジェクト: llwx593/Efficientnet
 def complexity(cx, w_in, w_out, stride, w_b, groups):
     """Accumulate the cost of a bottleneck transform (1x1 -> 3x3 -> 1x1).

     Each conv is followed by a norm layer. The stride goes on the first 1x1
     conv when ``cfg.RESNET.STRIDE_1X1`` is truthy; otherwise it goes on the
     3x3 conv. `groups` applies to the 3x3 conv only.
     """
     if cfg.RESNET.STRIDE_1X1:
         stride_1x1, stride_3x3 = stride, 1
     else:
         stride_1x1, stride_3x3 = 1, stride
     # (in width, out width, kernel, extra conv2d_cx keyword arguments)
     layers = (
         (w_in, w_b, 1, {"stride": stride_1x1}),
         (w_b, w_b, 3, {"stride": stride_3x3, "groups": groups}),
         (w_b, w_out, 1, {}),
     )
     for c_in, c_out, kernel, extra in layers:
         cx = conv2d_cx(cx, c_in, c_out, kernel, **extra)
         cx = norm2d_cx(cx, c_out)
     return cx
コード例 #4
0
 def complexity(cx, w_in, exp_r, k, stride, se_r, w_out):
     """Accumulate the cost of an expand -> depthwise -> SE -> project block.

     Layers: an optional 1x1 expansion to int(w_in * exp_r) channels (skipped
     when that equals w_in), a k x k depthwise conv (groups == channels)
     carrying `stride`, an SE module, and a 1x1 projection to w_out. Each conv
     is followed by a norm layer.
     """
     exp_w = int(w_in * exp_r)
     # No expansion conv when the expanded width equals the input width.
     if exp_w != w_in:
         cx = norm2d_cx(conv2d_cx(cx, w_in, exp_w, 1), exp_w)
     # Depthwise conv: one group per channel.
     dw_cx = conv2d_cx(cx, exp_w, exp_w, k, stride=stride, groups=exp_w)
     cx = norm2d_cx(dw_cx, exp_w)
     # SE width is derived from the block input width, not the expanded width.
     cx = SE.complexity(cx, exp_w, int(w_in * se_r))
     # 1x1 projection to the output width.
     cx = norm2d_cx(conv2d_cx(cx, exp_w, w_out, 1), w_out)
     return cx
コード例 #5
0
 def complexity(cx, w_in, w_out, stride, params):
     """Accumulate the cost of a bottleneck transform with an optional SE.

     1x1 conv (w_in -> bottleneck width) -> 3x3 grouped conv carrying
     `stride` -> SE (only when the computed SE width is non-zero) -> 1x1 conv
     (bottleneck width -> w_out). Each conv is followed by a norm layer.

     Args:
         params: dict read for "bot_mul", "se_r" and "group_w".
     """
     bot_w = int(round(w_out * params["bot_mul"]))
     se_w = int(round(w_in * params["se_r"]))
     n_groups = bot_w // params["group_w"]

     cx = norm2d_cx(conv2d_cx(cx, w_in, bot_w, 1), bot_w)
     grouped = conv2d_cx(cx, bot_w, bot_w, 3, stride=stride, groups=n_groups)
     cx = norm2d_cx(grouped, bot_w)
     # An SE width that rounds to 0 disables the SE module.
     if se_w:
         cx = SE.complexity(cx, bot_w, se_w)
     cx = norm2d_cx(conv2d_cx(cx, bot_w, w_out, 1), w_out)
     return cx
コード例 #6
0
 def complexity(cx, w_in, w_out, stride, params):
     """Accumulate the cost of a residual bottleneck block.

     When the shortcut needs a projection (width change or stride != 1), a
     strided 1x1 conv + norm is counted first. Then cx["h"] and cx["w"] are
     reset to their pre-projection values so the transform is costed from the
     block's input resolution. Finally the BottleneckTransform cost is added.
     """
     needs_proj = w_in != w_out or stride != 1
     if needs_proj:
         in_hw = (cx["h"], cx["w"])
         cx = norm2d_cx(conv2d_cx(cx, w_in, w_out, 1, stride=stride), w_out)
         # Undo the projection's downsampling; the transform is passed
         # `stride` and presumably applies it itself.
         cx["h"], cx["w"] = in_hw
     return BottleneckTransform.complexity(cx, w_in, w_out, stride, params)
コード例 #7
0
ファイル: resnet.py プロジェクト: llwx593/Efficientnet
 def complexity(cx, w_in, w_out, stride, trans_fun, w_b, groups):
     """Accumulate the cost of a residual block built around `trans_fun`.

     If the block changes width or has stride != 1, the shortcut's strided
     1x1 conv + norm is counted first, and cx["h"] / cx["w"] are reset to the
     block-input values afterwards. The cost of `trans_fun` (called with the
     same stride, `w_b` and `groups`) is then added and returned.
     """
     if w_in == w_out and stride == 1:
         # No projection: only the transform contributes.
         return trans_fun.complexity(cx, w_in, w_out, stride, w_b, groups)
     in_h, in_w = cx["h"], cx["w"]
     cx = conv2d_cx(cx, w_in, w_out, 1, stride=stride)
     cx = norm2d_cx(cx, w_out)
     # Undo the projection's downsampling; trans_fun receives `stride` and
     # presumably applies it itself.
     cx["h"], cx["w"] = in_h, in_w
     return trans_fun.complexity(cx, w_in, w_out, stride, w_b, groups)
コード例 #8
0
 def complexity(cx, w_in, head_width, num_classes):
     """Accumulate the cost of a classification head.

     If head_width > 0, a 1x1 conv + norm to `head_width` is counted first.
     Then `gap2d_cx` (presumably global average pooling) and a biased linear
     layer to `num_classes` are counted, both at the resulting width.
     """
     use_conv = head_width > 0
     feat_w = head_width if use_conv else w_in
     if use_conv:
         cx = norm2d_cx(conv2d_cx(cx, w_in, head_width, 1), head_width)
     cx = gap2d_cx(cx, feat_w)
     return linear_cx(cx, feat_w, num_classes, bias=True)
コード例 #9
0
 def complexity(cx, w_in, ks, ws, ss):
     """Accumulate the cost of a sequential stack of conv layers.

     Layer settings are zipped from kernel sizes `ks`, widths `ws` and strides
     `ss`. The layer at index len(ks) - 1 is a k x k conv with bias and no
     norm. Every other layer is a 3x3 conv + norm, and its entry in `ks` is
     ignored. Each layer's output width is the next layer's input width.
     """
     c_in = w_in
     for idx, (kernel, c_out, step) in enumerate(zip(ks, ws, ss)):
         if idx == len(ks) - 1:
             cx = conv2d_cx(cx, c_in, c_out, kernel, stride=step, bias=True)
         else:
             cx = conv2d_cx(cx, c_in, c_out, 3, stride=step)
             cx = norm2d_cx(cx, c_out)
         c_in = c_out
     return cx
コード例 #10
0
 def complexity(cx, w_in, w_out, num_classes):
     """Accumulate the cost of a head: 1x1 conv + norm, gap2d_cx, linear.

     The 1x1 conv maps w_in -> w_out. Pooling and the biased linear layer to
     `num_classes` both operate at width w_out.
     """
     cx = norm2d_cx(conv2d_cx(cx, w_in, w_out, 1), w_out)
     cx = gap2d_cx(cx, w_out)
     return linear_cx(cx, w_out, num_classes, bias=True)
コード例 #11
0
 def complexity(cx, w_in, w_out):
     """Accumulate the cost of a 3x3 stride-2 conv (w_in -> w_out) plus norm."""
     return norm2d_cx(conv2d_cx(cx, w_in, w_out, 3, stride=2), w_out)
コード例 #12
0
 def complexity(cx, w_in, w_out, stride, _params):
     """Accumulate the cost of two 3x3 conv + norm pairs.

     The first conv maps w_in -> w_out and carries `stride`. The second keeps
     w_out channels. `_params` is accepted for signature compatibility and is
     not read.
     """
     first = conv2d_cx(cx, w_in, w_out, 3, stride=stride)
     cx = norm2d_cx(first, w_out)
     second = conv2d_cx(cx, w_out, w_out, 3)
     return norm2d_cx(second, w_out)