def __init__(self, C=1, N=1, K=1, D=1, H=1, W=1, T=1, R=1, S=1,
             pad_d=0, pad_h=0, pad_w=0, str_d=1, str_h=1, str_w=1,
             dil_d=1, dil_h=1, dil_w=1):
    """Collect shapes, ngraph axes and op arguments for a 3D convolution case.

    Args:
        C, D, H, W, N: lengths of the input axes named 'C', 'D', 'H', 'W'
            and 'N'.
        K: length of the filter 'K' axis, which is also the length of the
            output 'C' axis.
        T, R, S: filter extents along D, H and W.
        pad_d, pad_h, pad_w: padding along D, H and W.
        str_d, str_h, str_w: strides along D, H and W.
        dil_d, dil_h, dil_w: dilations along D, H and W (default 1, which
            reproduces the previous fixed behaviour). They are forwarded in
            ``conv_params`` and used for the output-size computation.

    Sets ``dimI``, ``dimF``, ``dimO`` (shape tuples), ``conv_params``
    (keyword arguments for the op), ``batch_axis`` and ``ax_i``, ``ax_f``,
    ``ax_o`` (ngraph axes; input and output share ``batch_axis``).
    """
    from ngraph.frontends.common.utils import conv_output_dim

    # A filter of extent F dilated by d touches d * (F - 1) + 1 input
    # positions, so output sizes are computed from that effective extent.
    # Otherwise ax_o would disagree with the real output whenever
    # dilation > 1 (assuming the op honours the dil_* entries below).
    eff_T = dil_d * (T - 1) + 1
    eff_R = dil_h * (R - 1) + 1
    eff_S = dil_w * (S - 1) + 1

    M = conv_output_dim(D, eff_T, pad_d, str_d)
    P = conv_output_dim(H, eff_R, pad_h, str_h)
    Q = conv_output_dim(W, eff_S, pad_w, str_w)

    self.dimO = (K, M, P, Q, N)
    self.dimI = (C, D, H, W, N)
    self.dimF = (C, T, R, S, K)
    self.conv_params = dict(pad_d=pad_d, pad_h=pad_h, pad_w=pad_w,
                            str_d=str_d, str_h=str_h, str_w=str_w,
                            dil_d=dil_d, dil_h=dil_h, dil_w=dil_w)

    self.batch_axis = ng.make_axis(name='N', length=N)
    self.ax_i = ng.make_axes([
        ng.make_axis(name='C', length=C),
        ng.make_axis(name='D', length=D),
        ng.make_axis(name='H', length=H),
        ng.make_axis(name='W', length=W),
        self.batch_axis
    ])
    self.ax_f = ng.make_axes([
        ng.make_axis(name='C', length=C),
        ng.make_axis(name='D', length=T),
        ng.make_axis(name='H', length=R),
        ng.make_axis(name='W', length=S),
        ng.make_axis(name='K', length=K),
    ])
    self.ax_o = ng.make_axes([
        ng.make_axis(name='C', length=K),
        ng.make_axis(name='D', length=M),
        ng.make_axis(name='H', length=P),
        ng.make_axis(name='W', length=Q),
        self.batch_axis
    ])
def __init__(self, C=1, N=1, D=1, H=1, W=1, J=1, T=1, R=1, S=1,
             pad_c=0, pad_d=0, pad_h=0, pad_w=0,
             str_c=1, str_d=1, str_h=1, str_w=1, op='max'):
    """Collect shapes, ngraph axes and op arguments for a 4-axis pooling case.

    Pooling windows of extent J, T, R, S slide over the input axes C, D, H
    and W with the matching pad_* / str_* values. Output lengths along each
    axis come from ``conv_output_dim``.

    Sets ``dimI`` and ``dimO`` (shape tuples), ``dimF`` (window extents plus
    the output 'C' length), ``pool_params`` (keyword arguments for the op,
    including ``op``) and ``ax_i`` / ``ax_o`` (ngraph axes sharing one 'N'
    batch axis).
    """
    # (input length, window extent, padding, stride) for C, D, H, W in turn.
    per_axis = (
        (C, J, pad_c, str_c),
        (D, T, pad_d, str_d),
        (H, R, pad_h, str_h),
        (W, S, pad_w, str_w),
    )
    K, M, P, Q = (conv_output_dim(length, window, pad, stride)
                  for length, window, pad, stride in per_axis)

    self.dimO = (K, M, P, Q, N)
    self.dimI = (C, D, H, W, N)
    self.dimF = (J, T, R, S, K)
    self.pool_params = {
        'pad_c': pad_c, 'pad_d': pad_d, 'pad_h': pad_h, 'pad_w': pad_w,
        'str_c': str_c, 'str_d': str_d, 'str_h': str_h, 'str_w': str_w,
        'J': J, 'T': T, 'R': R, 'S': S,
        'op': op,
    }

    shared_batch = ng.make_axis(name='N', length=N)

    def build_axes(lengths):
        # Named C, D, H, W axes with the given lengths, then the batch axis.
        named = [ng.make_axis(name=axis_name, length=size)
                 for axis_name, size in zip(('C', 'D', 'H', 'W'), lengths)]
        named.append(shared_batch)
        return ng.make_axes(named)

    self.ax_i = build_axes((C, D, H, W))
    self.ax_o = build_axes((K, M, P, Q))
def test_conv_inverts_deconv(input_size, filter_size, padding, stride):
    """Check that deconvolution undoes convolution when both use the same parameters.

    ``padding`` goes through ``sum()``, so it must be an iterable of pad
    amounts. The same object is also passed unchanged to the ``utils``
    size helpers.
    """
    # A convolution whose padded input (minus the filter) is not an exact
    # multiple of the stride drops `remainder` trailing elements. Deconv can
    # only recover the size up to that remainder.
    remainder = (input_size + sum(padding) - filter_size) % stride
    expected_size = input_size - remainder

    conv_size = utils.conv_output_dim(input_size, filter_size, padding, stride)
    recovered_size = utils.deconv_output_dim(conv_size, filter_size, padding, stride)

    assert recovered_size == expected_size, (
        "Convolution and Deconvolution do not invert:\n"
        "output ({}) != input ({}) - a ({})\n"
        "filter: {}, padding: {}, stride: {}"
    ).format(recovered_size, input_size, remainder, filter_size, padding, stride)
def __init__(self, C=1, N=1, K=1, D=1, H=1, W=1, T=1, R=1, S=1,
             pad_d=0, pad_h=0, pad_w=0, str_d=1, str_h=1, str_w=1,
             dil_d=1, dil_h=1, dil_w=1, deconv=False):
    """Collect shapes, ngraph axes and op arguments for a 3D (de)convolution case.

    Args:
        C, D, H, W, N: lengths of the input axes named 'C', 'D', 'H', 'W'
            and 'N'.
        K: length of the output 'C' axis.
        T, R, S: filter extents along D, H and W.
        pad_d, pad_h, pad_w: padding along D, H and W.
        str_d, str_h, str_w: strides along D, H and W.
        dil_d, dil_h, dil_w: dilations along D, H and W. They are forwarded
            in ``conv_params`` and also used for the output-size computation.
        deconv: when true, output sizes come from ``deconv_output_dim`` and
            the filter's first/last axes hold K and C respectively (instead
            of C and K).

    Sets ``dimI``, ``dimF``, ``dimO`` (shape tuples), ``conv_params``
    (keyword arguments for the op) and ``ax_i``, ``ax_f``, ``ax_o``
    (ngraph axes; input and output share one 'N' batch axis).
    """
    # A filter of extent F dilated by d touches d * (F - 1) + 1 positions.
    # Output sizes must use that effective extent. The raw extent made ax_o
    # disagree with the real output whenever dilation > 1 (assuming the op
    # honours the dil_* entries of conv_params). With dilation 1 the result
    # is unchanged.
    eff_T = dil_d * (T - 1) + 1
    eff_R = dil_h * (R - 1) + 1
    eff_S = dil_w * (S - 1) + 1

    output_dim = deconv_output_dim if deconv else conv_output_dim
    M = output_dim(D, eff_T, pad_d, str_d)
    P = output_dim(H, eff_R, pad_h, str_h)
    Q = output_dim(W, eff_S, pad_w, str_w)

    # Deconvolution swaps which channel count sits on the filter's first
    # ('C') and last ('K') axes.
    f_first, f_last = (K, C) if deconv else (C, K)

    self.dimO = (K, M, P, Q, N)
    self.dimI = (C, D, H, W, N)
    self.dimF = (f_first, T, R, S, f_last)
    self.conv_params = dict(pad_d=pad_d, pad_h=pad_h, pad_w=pad_w,
                            str_d=str_d, str_h=str_h, str_w=str_w,
                            dil_d=dil_d, dil_h=dil_h, dil_w=dil_w)

    batch_axis = ng.make_axis(name='N', length=N)
    self.ax_i = ng.make_axes([
        ng.make_axis(name='C', length=C),
        ng.make_axis(name='D', length=D),
        ng.make_axis(name='H', length=H),
        ng.make_axis(name='W', length=W),
        batch_axis
    ])
    self.ax_f = ng.make_axes([
        ng.make_axis(name='C', length=f_first),
        ng.make_axis(name='D', length=T),
        ng.make_axis(name='H', length=R),
        ng.make_axis(name='W', length=S),
        ng.make_axis(name='K', length=f_last),
    ])
    self.ax_o = ng.make_axes([
        ng.make_axis(name='C', length=K),
        ng.make_axis(name='D', length=M),
        ng.make_axis(name='H', length=P),
        ng.make_axis(name='W', length=Q),
        batch_axis
    ])