예제 #1
0
    def get_output_shape_for(self, input_shape):
        """Compute the output shape of this (atrous) 2D convolution.

        The spatial sizes are taken from axes 2/3 of ``input_shape`` for
        'th' ordering, or from axes 1/2 for 'tf' ordering. Each size is then
        passed through ``conv_output_length`` with the layer's kernel size,
        border mode, stride and dilation rate. The filter count is placed at
        axis 1 ('th') or at the last axis ('tf').

        Raises:
            ValueError: if ``self.dim_ordering`` is neither 'th' nor 'tf'.
        """
        ordering = self.dim_ordering
        if ordering == 'th':
            row_axis, col_axis = 2, 3
        elif ordering == 'tf':
            row_axis, col_axis = 1, 2
        else:
            raise ValueError('Invalid dim_ordering:', self.dim_ordering)
        in_rows = input_shape[row_axis]
        in_cols = input_shape[col_axis]

        if self.causal:
            # Row count is grown by atrous_rate[0] * (nb_row - 1) before the
            # length computation (presumably to account for causal padding).
            in_rows += self.atrous_rate[0] * (self.nb_row - 1)

        out_rows = conv_output_length(in_rows, self.nb_row, self.border_mode,
                                      self.subsample[0],
                                      dilation=self.atrous_rate[0])
        out_cols = conv_output_length(in_cols, self.nb_col, self.border_mode,
                                      self.subsample[1],
                                      dilation=self.atrous_rate[1])

        batch = input_shape[0]
        if ordering == 'th':
            return (batch, self.nb_filter, out_rows, out_cols)
        return (batch, out_rows, out_cols, self.nb_filter)
예제 #2
0
    def get_output_shape_for(self, input_shape):
        """Compute the output shape of this convolutional recurrent layer.

        ``input_shape`` has an extra axis at position 1 (presumably time).
        The spatial axes are 3/4 for 'th' ordering and 2/3 for 'tf' ordering.
        If ``self.return_sequences`` is true, axis 1 is kept in the result;
        otherwise it is dropped.

        Raises:
            Exception: if ``self.dim_ordering`` is neither 'th' nor 'tf'.
        """
        if self.dim_ordering == 'th':
            in_rows, in_cols = input_shape[3], input_shape[4]
        elif self.dim_ordering == 'tf':
            in_rows, in_cols = input_shape[2], input_shape[3]
        else:
            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)

        out_rows = conv_output_length(in_rows, self.nb_row, self.border_mode,
                                      self.subsample[0])
        out_cols = conv_output_length(in_cols, self.nb_col, self.border_mode,
                                      self.subsample[1])

        # dim_ordering was validated above, so only 'th' / 'tf' reach here.
        if self.return_sequences:
            leading = (input_shape[0], input_shape[1])
        else:
            leading = (input_shape[0],)
        if self.dim_ordering == 'th':
            trailing = (self.nb_filter, out_rows, out_cols)
        else:
            trailing = (out_rows, out_cols, self.nb_filter)
        return leading + trailing
예제 #3
0
    def get_output_shape_for(self, input_shape):
        """Compute the output shape from the first shape in ``input_shape``.

        ``input_shape`` is indexed as a sequence of shapes, and only
        ``input_shape[0]`` is used. Its axes 1 and 2 are the spatial sizes,
        and both are reduced with ``self.kernel_size``. Axis 0 (presumably
        batch) and the last axis are copied through unchanged.

        Raises:
            ValueError: unless ``self.dim_ordering`` is 'tf'.
        """
        if self.dim_ordering != 'tf':
            raise ValueError('Only support tensorflow.')

        primary = input_shape[0]
        height, width = primary[1], primary[2]

        new_height = conv_output_length(height, self.kernel_size,
                                        self.border_mode, self.subsample[0])
        new_width = conv_output_length(width, self.kernel_size,
                                       self.border_mode, self.subsample[1])

        return (primary[0], new_height, new_width, primary[-1])
예제 #4
0
 def get_output_shape_for(self, input_shape):
     """Return ``(input_shape[0], new_length, self.output_dim)``.

     When axis 1 of ``input_shape`` is truthy, it is recomputed as a 'valid'
     convolution over ``length + window_size - 1`` elements with stride
     ``subsample[0]``. A falsy length (e.g. ``None`` or ``0``) is passed
     through unchanged.
     """
     steps = input_shape[1]
     if not steps:
         return (input_shape[0], steps, self.output_dim)
     padded_steps = steps + self.window_size - 1
     steps = conv_output_length(padded_steps, self.window_size, 'valid',
                                self.subsample[0])
     return (input_shape[0], steps, self.output_dim)
예제 #5
0
 def get_output_shape_for(self, input_shape):
     """Return the output shape of this 1D (atrous) convolution.

     Axis 1 of ``input_shape`` is passed through ``conv_output_length`` with
     the layer's filter length, border mode, stride and dilation rate. Axis 0
     is copied through, and ``self.nb_filter`` becomes the last axis.
     """
     new_length = conv_output_length(
         input_shape[1],
         self.filter_length,
         self.border_mode,
         self.subsample[0],
         dilation=self.atrous_rate,
     )
     return input_shape[0], new_length, self.nb_filter
예제 #6
0
    def build(self, input_shape):
        """Record input geometry and create the DFT kernel variables.

        The channel count is read from ``input_shape[1]`` and the signal
        length from ``input_shape[2]``. The shape is presumably
        (batch, n_ch, length); confirm against callers.

        The real and imaginary kernels from ``backend.get_stft_kernels`` are
        wrapped as ``K.variable`` objects. They are registered as trainable
        or non-trainable weights according to ``self.trainable_kernel``.

        Raises:
            AssertionError: if the signal is shorter than ``self.n_dft``
                (only while assertions are enabled).
        """
        self.n_ch = input_shape[1]
        self.len_src = input_shape[2]
        self.is_mono = self.n_ch == 1
        self.ch_axis_idx = 1 if self.dim_ordering == 'th' else 3
        # NOTE(review): ``assert`` is stripped under ``python -O``; kept so
        # the raised exception type stays the same.
        assert self.len_src >= self.n_dft, 'Hey! The input is too short!'

        self.n_frame = conv_output_length(self.len_src, self.n_dft,
                                          self.padding, self.n_hop)

        real_np, imag_np = backend.get_stft_kernels(self.n_dft)
        self.dft_real_kernels = K.variable(real_np, dtype=K.floatx())
        self.dft_imag_kernels = K.variable(imag_np, dtype=K.floatx())

        if self.trainable_kernel:
            weight_list = self.trainable_weights
        else:
            weight_list = self.non_trainable_weights
        weight_list.append(self.dft_real_kernels)
        weight_list.append(self.dft_imag_kernels)

        super(Spectrogram, self).build(input_shape)
	def get_output_shape_for(self, input_shape):
		"""Compute the output shape of this pooling layer.

		``input_shape`` is indexed up to axis 4. For 'th' ordering, axes
		2, 3 and 4 are channels, rows and cols. For 'tf' ordering, they are
		rows, cols and channels. Axis 1 is not carried into the result.

		When ``self.spatial_pool is True``, rows and cols are each passed
		through ``conv_output_length``. Rows use ``pool_size[0]`` and
		``strides[0]``; cols use ``pool_size[1]`` and ``strides[1]``.

		Raises:
			ValueError: if ``self.dim_ordering`` is neither 'th' nor 'tf'.
		"""
		if self.dim_ordering == 'th':
			channels = input_shape[2]
			rows = input_shape[3]
			cols = input_shape[4]
		elif self.dim_ordering == 'tf':
			rows = input_shape[2]
			cols = input_shape[3]
			channels = input_shape[4]
		else:
			raise ValueError('Invalid dim_ordering:', self.dim_ordering)

		if self.spatial_pool is True:
			rows = conv_output_length(
				rows, self.pool_size[0], self.border_mode, self.strides[0])
			# The column axis uses index 1 of pool_size / strides; reusing
			# index 0 is only correct for square pools with equal strides.
			cols = conv_output_length(
				cols, self.pool_size[1], self.border_mode, self.strides[1])

		# dim_ordering was validated above, so only 'th' / 'tf' reach here.
		if self.dim_ordering == 'th':
			return (input_shape[0], channels, rows, cols)
		return (input_shape[0], rows, cols, channels)
예제 #8
0
def get_conv_output_shape(batch_input_shape, kernel_size, subsample,
                          dim_ordering, border_mode):
    """Compute the output shape of a 2D convolution over one input shape.

    Args:
        batch_input_shape: shape indexed at 0..2. It is read as
            (features, rows, cols) for 'th' ordering and as
            (rows, cols, features) for 'tf' ordering.
        kernel_size: indexed as (kernel_rows, kernel_cols).
        subsample: indexed as (row_stride, col_stride).
        dim_ordering: 'th', 'tf', or 'default'. 'default' is resolved via
            ``K.image_dim_ordering()``.
        border_mode: 'valid' or 'same'.

    Returns:
        ``[output_rows, output_cols, nb_features]``. This is always in
        rows/cols/features order regardless of ``dim_ordering``, and the
        feature count is passed through unchanged.

    Raises:
        Exception: on an unknown ``border_mode`` or ``dim_ordering``.
    """
    if border_mode not in {'valid', 'same'}:
        raise Exception('Invalid border mode:', border_mode)
    if dim_ordering == 'default':
        dim_ordering = K.image_dim_ordering()
    if dim_ordering == 'th':
        nb_features = batch_input_shape[0]
        rows = batch_input_shape[1]
        cols = batch_input_shape[2]
    elif dim_ordering == 'tf':
        nb_features = batch_input_shape[2]
        rows = batch_input_shape[0]
        cols = batch_input_shape[1]
    else:
        raise Exception('Invalid dim_ordering: ' + dim_ordering)
    output_rows = np_utils.conv_output_length(rows, kernel_size[0],
                                              border_mode, subsample[0])
    output_cols = np_utils.conv_output_length(cols, kernel_size[1],
                                              border_mode, subsample[1])
    # NOTE(review): the result is channels-last even for 'th' input;
    # confirm callers expect this ordering.
    return [output_rows, output_cols, nb_features]
예제 #9
0
파일: stft.py 프로젝트: sudhirshahu51/kapre
 def build(self, input_shape):
     """Record input geometry and prepare the STFT window.

     The channel count is read from ``input_shape[1]`` and the signal length
     from ``input_shape[2]``. The frame count is computed with a 'valid'
     border and hop ``self.n_hop``.

     The result of ``backend._hann(self.n_fft, sym=False)`` (presumably a
     periodic Hann window, per the helper's name) is stored as
     ``self.fft_window``, and the layer is marked as built.

     Raises:
         AssertionError: if the signal is shorter than ``self.n_fft``
             (only while assertions are enabled).
     """
     n_ch = input_shape[1]
     len_src = input_shape[2]
     self.n_ch = n_ch
     self.len_src = len_src
     self.is_mono = n_ch == 1
     self.ch_axis_idx = 1 if self.dim_ordering == 'th' else 3
     assert len_src >= self.n_fft, 'Hey! The input is too short!'
     self.n_frame = conv_output_length(len_src, self.n_fft, 'valid',
                                       self.n_hop)
     self.fft_window = backend._hann(self.n_fft, sym=False)
     self.built = True
 def get_output_shape_for(self, input_shape):
     """Return the output shape of this 1D pooling layer.

     Axis 1 of ``input_shape`` is passed through ``conv_output_length`` with
     ``self.pool_length``, ``self.border_mode`` and ``self.stride``. Axes 0
     and 2 are copied through unchanged.
     """
     pooled_length = conv_output_length(
         input_shape[1],
         self.pool_length,
         self.border_mode,
         self.stride,
     )
     return input_shape[0], pooled_length, input_shape[2]