예제 #1
0
        def __init__(self,
                     C=1, N=1, K=1,
                     D=1, H=1, W=1,
                     T=1, R=1, S=1,
                     pad_d=0, pad_h=0, pad_w=0,
                     str_d=1, str_h=1, str_w=1):
            """Build shapes, parameters and ngraph axes for a 3-D convolution case.

            Arguments:
                C: channels of the input tensor.
                N: batch size.
                K: channels of the output tensor.
                D, H, W: input spatial extents.
                T, R, S: filter extents along D, H and W.
                pad_d, pad_h, pad_w: padding along D, H and W.
                str_d, str_h, str_w: stride along D, H and W.

            Sets the ``dimI``/``dimF``/``dimO`` shape tuples, a ``conv_params``
            dict whose dilation entries are always 1, the ``batch_axis`` shared
            by input and output, and the axis sets ``ax_i``, ``ax_f``, ``ax_o``.
            """
            from ngraph.frontends.common.utils import conv_output_dim

            # (size, filter, pad, stride) per spatial dimension, in D, H, W order.
            per_dim = ((D, T, pad_d, str_d),
                       (H, R, pad_h, str_h),
                       (W, S, pad_w, str_w))
            M, P, Q = [conv_output_dim(*dim_args) for dim_args in per_dim]

            self.dimI = (C, D, H, W, N)
            self.dimF = (C, T, R, S, K)
            self.dimO = (K, M, P, Q, N)

            self.conv_params = {
                'pad_d': pad_d,
                'pad_h': pad_h,
                'pad_w': pad_w,
                'str_d': str_d,
                'str_h': str_h,
                'str_w': str_w,
                # This test case never dilates the filter.
                'dil_d': 1,
                'dil_h': 1,
                'dil_w': 1,
            }

            # A single batch axis object is reused by the input and output axes.
            self.batch_axis = ng.make_axis(name='N', length=N)

            axis_names = ('C', 'D', 'H', 'W')
            self.ax_i = ng.make_axes(
                [ng.make_axis(name=ax_name, length=ax_len)
                 for ax_name, ax_len in zip(axis_names, (C, D, H, W))]
                + [self.batch_axis])
            # Filter axes reuse the input names; output channels sit on a trailing 'K'.
            self.ax_f = ng.make_axes(
                [ng.make_axis(name=ax_name, length=ax_len)
                 for ax_name, ax_len in zip(axis_names + ('K',), (C, T, R, S, K))])
            self.ax_o = ng.make_axes(
                [ng.make_axis(name=ax_name, length=ax_len)
                 for ax_name, ax_len in zip(axis_names, (K, M, P, Q))]
                + [self.batch_axis])
예제 #2
0
파일: test_pool.py 프로젝트: QiJune/ngraph
    def __init__(self,
                 C=1, N=1,
                 D=1, H=1, W=1,
                 J=1, T=1, R=1, S=1,
                 pad_c=0, pad_d=0, pad_h=0, pad_w=0,
                 str_c=1, str_d=1, str_h=1, str_w=1,
                 op='max'):
        """Build shapes, parameters and ngraph axes for a 4-D pooling case.

        Arguments:
            C, D, H, W: input extents (channel plus three spatial dimensions).
            N: batch size.
            J, T, R, S: pooling window extents along C, D, H and W.
            pad_c, pad_d, pad_h, pad_w: padding along C, D, H and W.
            str_c, str_d, str_h, str_w: stride along C, D, H and W.
            op: pooling operation name, stored verbatim in ``pool_params``.

        Sets the ``dimI``/``dimF``/``dimO`` shape tuples, the ``pool_params``
        dict and the axis sets ``ax_i`` and ``ax_o`` (which share one batch axis).
        """
        in_sizes = (C, D, H, W)
        windows = (J, T, R, S)
        pads = (pad_c, pad_d, pad_h, pad_w)
        strides = (str_c, str_d, str_h, str_w)

        # Output extent per dimension, evaluated in C, D, H, W order.
        K, M, P, Q = [conv_output_dim(size, win, pad, stride)
                      for size, win, pad, stride in zip(in_sizes, windows, pads, strides)]

        self.dimI = (C, D, H, W, N)
        self.dimF = (J, T, R, S, K)
        self.dimO = (K, M, P, Q, N)

        self.pool_params = {
            'pad_c': pad_c,
            'pad_d': pad_d,
            'pad_h': pad_h,
            'pad_w': pad_w,
            'str_c': str_c,
            'str_d': str_d,
            'str_h': str_h,
            'str_w': str_w,
            'J': J,
            'T': T,
            'R': R,
            'S': S,
            'op': op,
        }

        # The same batch axis object terminates both the input and output axes.
        batch_axis = ng.make_axis(name='N', length=N)

        axis_names = ('C', 'D', 'H', 'W')
        self.ax_i = ng.make_axes(
            [ng.make_axis(name=ax_name, length=ax_len)
             for ax_name, ax_len in zip(axis_names, in_sizes)]
            + [batch_axis])
        self.ax_o = ng.make_axes(
            [ng.make_axis(name=ax_name, length=ax_len)
             for ax_name, ax_len in zip(axis_names, (K, M, P, Q))]
            + [batch_axis])
예제 #3
0
def test_conv_inverts_deconv(input_size, filter_size, padding, stride):
    """Check that deconvolution undoes convolution when both use the same parameters.

    The input size is pushed through ``utils.conv_output_dim`` and the result
    back through ``utils.deconv_output_dim``; the round trip must land on the
    original size minus whatever remainder the strided convolution discards.
    ``padding`` is summed below, so it is expected to be an iterable of pads.
    """
    # When (padded input - filter) is not a multiple of the stride, the
    # convolution drops that remainder, and the inverse cannot recover it.
    remainder = (input_size + sum(padding) - filter_size) % stride

    conv_out = utils.conv_output_dim(input_size, filter_size, padding, stride)
    round_trip = utils.deconv_output_dim(conv_out, filter_size, padding, stride)

    assert round_trip == (input_size - remainder), (
        "Convolution and Deconvolution do not invert:\n"
        "output ({}) != input ({}) - a ({})\n"
        "filter: {}, padding: {}, stride: {}"
    ).format(round_trip, input_size, remainder, filter_size, padding, stride)
예제 #4
0
    def __init__(self,
                 C=1, N=1, K=1,
                 D=1, H=1, W=1,
                 T=1, R=1, S=1,
                 pad_d=0, pad_h=0, pad_w=0,
                 str_d=1, str_h=1, str_w=1,
                 dil_d=1, dil_h=1, dil_w=1,
                 deconv=False):
        """Build shapes, parameters and ngraph axes for a 3-D (de)convolution case.

        Arguments:
            C: channels of the input tensor.
            N: batch size.
            K: channels of the output tensor.
            D, H, W: input spatial extents.
            T, R, S: number of filter taps along D, H and W.
            pad_d, pad_h, pad_w: padding along D, H and W.
            str_d, str_h, str_w: stride along D, H and W.
            dil_d, dil_h, dil_w: filter dilation along D, H and W.
            deconv: if True, size the output with ``deconv_output_dim`` instead
                of ``conv_output_dim`` and swap the C/K lengths in the filter.

        Sets the ``dimI``/``dimF``/``dimO`` shape tuples, the ``conv_params``
        dict and the axis sets ``ax_i``, ``ax_f`` and ``ax_o`` (input and output
        share one batch axis object).
        """
        # A filter with `taps` taps dilated by `dil` spans dil * (taps - 1) + 1
        # input positions. Sizing the output with this effective extent keeps
        # dimO/ax_o consistent with the dilation recorded in conv_params (the
        # raw tap count used to be passed, ignoring dilation entirely). With the
        # default dilation of 1 the extent equals the tap count, so undilated
        # cases get exactly the same shapes as before.
        eff_T = dil_d * (T - 1) + 1
        eff_R = dil_h * (R - 1) + 1
        eff_S = dil_w * (S - 1) + 1

        output_dim = deconv_output_dim if deconv else conv_output_dim
        M = output_dim(D, eff_T, pad_d, str_d)
        P = output_dim(H, eff_R, pad_h, str_h)
        Q = output_dim(W, eff_S, pad_w, str_w)

        # For deconv the filter's leading/trailing channel lengths are swapped
        # (K first, C last); a plain convolution uses C first, K last.
        f_lead, f_trail = (K, C) if deconv else (C, K)

        self.dimO = (K, M, P, Q, N)
        self.dimI = (C, D, H, W, N)
        # The filter shape keeps the raw tap counts; dilation only affects how
        # far apart the taps are applied, not how many there are.
        self.dimF = (f_lead, T, R, S, f_trail)

        self.conv_params = dict(pad_d=pad_d,
                                pad_h=pad_h,
                                pad_w=pad_w,
                                str_d=str_d,
                                str_h=str_h,
                                str_w=str_w,
                                dil_d=dil_d,
                                dil_h=dil_h,
                                dil_w=dil_w)

        batch_axis = ng.make_axis(name='N', length=N)

        self.ax_i = ng.make_axes([
            ng.make_axis(name='C', length=C),
            ng.make_axis(name='D', length=D),
            ng.make_axis(name='H', length=H),
            ng.make_axis(name='W', length=W), batch_axis
        ])

        self.ax_f = ng.make_axes([
            ng.make_axis(name='C', length=f_lead),
            ng.make_axis(name='D', length=T),
            ng.make_axis(name='H', length=R),
            ng.make_axis(name='W', length=S),
            ng.make_axis(name='K', length=f_trail),
        ])

        self.ax_o = ng.make_axes([
            ng.make_axis(name='C', length=K),
            ng.make_axis(name='D', length=M),
            ng.make_axis(name='H', length=P),
            ng.make_axis(name='W', length=Q), batch_axis
        ])