Ejemplo n.º 1
0
def test_conv_inverts_deconv(input_size, filter_size, padding, stride):
    """Check that deconvolution undoes convolution for identical parameters.

    The input size is pushed through ``utils.conv_output_dim`` and the result
    through ``utils.deconv_output_dim`` using the same ``filter_size``,
    ``padding`` and ``stride``. The round trip must return ``input_size``
    minus the part that is not a whole multiple of the stride.
    """
    # If (input + total padding - filter) is not a multiple of the stride, the
    # leftover cannot be recovered exactly; the test allows exactly that loss.
    remainder = (input_size + sum(padding) - filter_size) % stride
    expected = input_size - remainder

    out_conv = utils.conv_output_dim(input_size, filter_size, padding, stride)
    out_deconv = utils.deconv_output_dim(out_conv, filter_size, padding,
                                         stride)

    message = ("Convolution and Deconvolution do not invert:\n"
               "output ({}) != input ({}) - a ({})\n"
               "filter: {}, padding: {}, stride: {}")
    assert out_deconv == expected, message.format(out_deconv, input_size,
                                                  remainder, filter_size,
                                                  padding, stride)
Ejemplo n.º 2
0
    def __init__(self,
                 C=1,
                 N=1,
                 K=1,
                 D=1,
                 H=1,
                 W=1,
                 T=1,
                 R=1,
                 S=1,
                 pad_d=0,
                 pad_h=0,
                 pad_w=0,
                 str_d=1,
                 str_h=1,
                 str_w=1,
                 dil_d=1,
                 dil_h=1,
                 dil_w=1,
                 deconv=False):
        """Record shapes, axes and parameters for a (de)convolution test.

        Arguments:
            C, D, H, W, N: lengths of the input's C/D/H/W/N axes
                (``dimI == (C, D, H, W, N)``).
            K: length of the output's leading 'C' axis.
            T, R, S: filter extents along D, H and W.
            pad_*, str_*, dil_*: per-dimension padding, stride and dilation;
                all nine are stored in ``conv_params``.
            deconv: when true, the output extents M/P/Q come from
                ``deconv_output_dim`` instead of ``conv_output_dim``. The
                filter's first/last extents (and the lengths of its 'C'/'K'
                axes) also become K/C instead of C/K.

        Sets ``dimO``, ``dimI``, ``dimF`` (shape tuples), ``conv_params``
        (dict) and ``ax_i``, ``ax_f``, ``ax_o`` (built with ``ng.make_axes``).
        """
        # Choose the output-extent helper once and apply it to D, H and W.
        # NOTE(review): dil_d/dil_h/dil_w are not passed to this helper, so
        # M/P/Q ignore dilation -- confirm that is intended.
        dim_fn = deconv_output_dim if deconv else conv_output_dim
        M = dim_fn(D, T, pad_d, str_d)
        P = dim_fn(H, R, pad_h, str_h)
        Q = dim_fn(W, S, pad_w, str_w)

        # In deconv mode the filter's first and last roles are swapped.
        f_lead, f_trail = (K, C) if deconv else (C, K)

        self.dimO = (K, M, P, Q, N)
        self.dimI = (C, D, H, W, N)
        self.dimF = (f_lead, T, R, S, f_trail)

        self.conv_params = {
            'pad_d': pad_d,
            'pad_h': pad_h,
            'pad_w': pad_w,
            'str_d': str_d,
            'str_h': str_h,
            'str_w': str_w,
            'dil_d': dil_d,
            'dil_h': dil_h,
            'dil_w': dil_w,
        }

        # A single batch axis object is shared by the input and output axes.
        batch_axis = ng.make_axis(name='N', length=N)

        def named_axes(pairs):
            """Create one ng axis per (name, length) pair, in order."""
            return [ng.make_axis(name=nm, length=ln) for nm, ln in pairs]

        self.ax_i = ng.make_axes(
            named_axes((('C', C), ('D', D), ('H', H), ('W', W))) +
            [batch_axis])

        self.ax_f = ng.make_axes(
            named_axes((('C', f_lead), ('D', T), ('H', R), ('W', S),
                        ('K', f_trail))))

        self.ax_o = ng.make_axes(
            named_axes((('C', K), ('D', M), ('H', P), ('W', Q))) +
            [batch_axis])