Example #1
0
    def get_output_shape_for(self, input_shape):
        """Compute this layer's output shape from ``input_shape``.

        Spatial sizes are read from ``input_shape`` at indices 3/4 when
        ``self.dim_ordering == 'th'`` and at indices 2/3 when it is ``'tf'``.
        Each one is passed through ``conv_output_length`` with the layer's
        kernel size, ``border_mode`` and ``subsample``. Index 0 is copied to
        the output unchanged. When ``self.return_sequences`` is true, index 1
        is copied as well.

        Returns:
            A 5-tuple if ``self.return_sequences`` is true, else a 4-tuple.
            ``self.nb_filter`` sits before the spatial sizes for ``'th'`` and
            after them for ``'tf'``.

        Raises:
            Exception: if ``self.dim_ordering`` is neither ``'th'`` nor ``'tf'``.
        """
        # Validate once so every later branch can assume 'th' or 'tf'; this
        # removes the unreachable `else: raise` arms of the original. str()
        # keeps the intended message even for non-string values, where '+'
        # would otherwise raise an unrelated TypeError.
        if self.dim_ordering not in ('th', 'tf'):
            raise Exception('Invalid dim_ordering: ' + str(self.dim_ordering))
        channels_first = self.dim_ordering == 'th'

        # The spatial axes sit one position later than in a plain 4D conv
        # input because index 1 is an extra (presumably time) axis -- TODO
        # confirm against callers.
        if channels_first:
            rows, cols = input_shape[3], input_shape[4]
        else:
            rows, cols = input_shape[2], input_shape[3]

        rows = conv_output_length(rows, self.nb_row, self.border_mode,
                                  self.subsample[0])
        cols = conv_output_length(cols, self.nb_col, self.border_mode,
                                  self.subsample[1])

        if channels_first:
            tail = (self.nb_filter, rows, cols)
        else:
            tail = (rows, cols, self.nb_filter)

        if self.return_sequences:
            return (input_shape[0], input_shape[1]) + tail
        return (input_shape[0],) + tail
Example #2
0
File: coding.py  Project: t13m/seya
 def output_shape(self):
     """Return the output shape derived from ``self.input_shape``.

     Indices 2 and 3 of ``self.input_shape`` are each passed through
     ``conv_output_length`` with ``self.border_mode`` and a fixed stride of
     1, using ``self.nb_row`` and ``self.nb_col`` respectively. The result is
     ``(None, self.stack_size, out_row, out_col)``. The leading entry is
     always ``None``, not taken from the input shape.
     """
     shape = self.input_shape
     mode = self.border_mode
     out_row = conv_output_length(shape[2], self.nb_row, mode, 1)
     out_col = conv_output_length(shape[3], self.nb_col, mode, 1)
     return (None, self.stack_size, out_row, out_col)
Example #3
0
File: coding.py  Project: jfsantos/seya
 def output_shape(self):
     """Return ``(None, stack_size, out_row, out_col)`` for this layer.

     ``out_row`` and ``out_col`` come from ``conv_output_length`` applied to
     entries 2 and 3 of ``self.input_shape``. That call uses the layer's
     ``nb_row`` / ``nb_col``, its ``border_mode`` and a stride of 1.
     """
     def _reduced(size, kernel):
         # Stride is hard-coded to 1 on both spatial axes.
         return conv_output_length(size, kernel, self.border_mode, 1)

     dims = self.input_shape
     rows_out = _reduced(dims[2], self.nb_row)
     cols_out = _reduced(dims[3], self.nb_col)
     return None, self.stack_size, rows_out, cols_out
    def get_output_shape_for(self, input_shape):
        """Compute this layer's output shape from ``input_shape``.

        With ``self.dim_ordering == 'th'`` the rows/cols are read from indices
        3 and 4 of ``input_shape``; with ``'tf'`` they come from indices 2 and
        3. Both go through ``conv_output_length`` using the layer's kernel
        size, ``border_mode`` and ``subsample``. Index 0 of ``input_shape``
        is always carried over. Index 1 is carried over only when
        ``self.return_sequences`` is true.

        Returns:
            A 5-tuple when ``self.return_sequences`` is true, otherwise a
            4-tuple, with ``self.nb_filter`` placed first (``'th'``) or last
            (``'tf'``) among the trailing dimensions.

        Raises:
            Exception: if ``self.dim_ordering`` is neither ``'th'`` nor ``'tf'``.
        """
        # Single up-front check; the original's later `else: raise` arms could
        # never run. str() avoids a confusing TypeError for non-str values.
        if self.dim_ordering not in ('th', 'tf'):
            raise Exception('Invalid dim_ordering: ' + str(self.dim_ordering))
        channels_first = self.dim_ordering == 'th'

        # Offset by one versus a 4D conv input: index 1 is an extra axis
        # (presumably time/sequence -- TODO confirm).
        if channels_first:
            rows, cols = input_shape[3], input_shape[4]
        else:
            rows, cols = input_shape[2], input_shape[3]

        rows = conv_output_length(rows, self.nb_row,
                                  self.border_mode, self.subsample[0])
        cols = conv_output_length(cols, self.nb_col,
                                  self.border_mode, self.subsample[1])

        if channels_first:
            tail = (self.nb_filter, rows, cols)
        else:
            tail = (rows, cols, self.nb_filter)

        if self.return_sequences:
            return (input_shape[0], input_shape[1]) + tail
        return (input_shape[0],) + tail
Example #5
0
    def _get_output_dim(self, input_shape):
        """Return the 4-tuple output shape using spatial sizes from ``reshape_dim``.

        The rows/cols fed to ``conv_output_length`` are taken from
        ``self.reshape_dim``: indices 2/3 for ``'th'`` and 1/2 for ``'tf'``.
        Only index 0 of ``input_shape`` is used, copied to the front of the
        result. ``self.nb_filter`` is placed before the spatial sizes for
        ``'th'`` and after them for ``'tf'``.

        Raises:
            Exception: if ``self.dim_ordering`` is neither ``'th'`` nor ``'tf'``.
        """
        # Validate once; the original's trailing `else: raise` was unreachable.
        # str() keeps the message intact for non-string dim_ordering values.
        if self.dim_ordering not in ('th', 'tf'):
            raise Exception('Invalid dim_ordering: ' + str(self.dim_ordering))
        channels_first = self.dim_ordering == 'th'

        if channels_first:
            rows, cols = self.reshape_dim[2], self.reshape_dim[3]
        else:
            rows, cols = self.reshape_dim[1], self.reshape_dim[2]

        rows = conv_output_length(rows, self.nb_row, self.border_mode,
                                  self.subsample[0])
        cols = conv_output_length(cols, self.nb_col, self.border_mode,
                                  self.subsample[1])

        if channels_first:
            return (input_shape[0], self.nb_filter, rows, cols)
        return (input_shape[0], rows, cols, self.nb_filter)
Example #6
0
    def get_output_shape_for(self, input_shape):
        """Compute the 4-tuple output shape of this 2D convolution.

        Rows/cols are read from ``input_shape`` at indices 2/3 (``'th'``) or
        1/2 (``'tf'``). Each is reduced with ``K_conv.conv_output_length``
        using the layer's kernel size, ``border_mode`` and ``subsample``.
        Index 0 of ``input_shape`` is copied to the front of the result.

        Note (review): Keras 2 renamed this hook to ``compute_output_shape``;
        confirm which Keras version this layer targets.

        Raises:
            Exception: if ``self.dim_ordering`` is neither ``'th'`` nor ``'tf'``.
        """
        # Validate once; the original's final `else: raise` was unreachable.
        # str() keeps the message intact for non-string dim_ordering values.
        if self.dim_ordering not in ('th', 'tf'):
            raise Exception('Invalid dim_ordering: ' + str(self.dim_ordering))
        channels_first = self.dim_ordering == 'th'

        if channels_first:
            rows, cols = input_shape[2], input_shape[3]
        else:
            rows, cols = input_shape[1], input_shape[2]

        rows = K_conv.conv_output_length(rows, self.nb_row, self.border_mode,
                                         self.subsample[0])
        cols = K_conv.conv_output_length(cols, self.nb_col, self.border_mode,
                                         self.subsample[1])

        if channels_first:
            return (input_shape[0], self.nb_filter, rows, cols)
        return (input_shape[0], rows, cols, self.nb_filter)
    def output_shape(self):
        """Return the 4-tuple output shape computed from ``self.input_shape``.

        Rows/cols are read from ``self.input_shape`` at indices 2/3 (``'th'``)
        or 1/2 (``'tf'``). Each is passed through ``conv_output_length`` with
        the layer's kernel size, ``border_mode`` and ``subsample``. Index 0 is
        copied to the front of the result.

        Raises:
            Exception: if ``self.dim_ordering`` is neither ``'th'`` nor ``'tf'``.
        """
        input_shape = self.input_shape
        # Validate once; the original's trailing `else: raise` was unreachable.
        # str() keeps the message intact for non-string dim_ordering values.
        if self.dim_ordering not in ('th', 'tf'):
            raise Exception('Invalid dim_ordering: ' + str(self.dim_ordering))
        channels_first = self.dim_ordering == 'th'

        if channels_first:
            rows, cols = input_shape[2], input_shape[3]
        else:
            rows, cols = input_shape[1], input_shape[2]

        rows = conv_output_length(rows, self.nb_row,
                                  self.border_mode, self.subsample[0])
        cols = conv_output_length(cols, self.nb_col,
                                  self.border_mode, self.subsample[1])

        if channels_first:
            return (input_shape[0], self.nb_filter, rows, cols)
        return (input_shape[0], rows, cols, self.nb_filter)
Example #8
0
    def output_shape(self):
        """Return the 4-tuple output shape computed from ``self.input_shape``.

        Rows/cols are read at indices 2/3 (``'th'``) or 1/2 (``'tf'``). Each
        is reduced with ``conv_output_length`` using ``'same'`` border mode
        (hard-coded here, regardless of any ``border_mode`` attribute) and
        ``self.strides``. Index 0 of the input shape is copied to the front.

        Raises:
            Exception: if ``self.dim_ordering`` is neither ``'th'`` nor ``'tf'``.
        """
        input_shape = self.input_shape
        # Validate once; the original's trailing `else: raise` was unreachable.
        # str() keeps the message intact for non-string dim_ordering values.
        if self.dim_ordering not in ('th', 'tf'):
            raise Exception('Invalid dim_ordering: ' + str(self.dim_ordering))
        channels_first = self.dim_ordering == 'th'

        if channels_first:
            rows, cols = input_shape[2], input_shape[3]
        else:
            rows, cols = input_shape[1], input_shape[2]

        rows = conv_output_length(rows, self.nb_row,
                                  'same', self.strides[0])
        cols = conv_output_length(cols, self.nb_col,
                                  'same', self.strides[1])

        if channels_first:
            return (input_shape[0], self.nb_filter, rows, cols)
        return (input_shape[0], rows, cols, self.nb_filter)
Example #9
0
 def get_output_shape_for(self, input_shape):
     """Return the 3-tuple output shape of a 1D convolution.

     Index 1 of ``input_shape`` is reduced with ``conv_output_length`` using
     ``self.filter_length``, ``self.border_mode`` and the first entry of
     ``self.subsample``. The result is
     ``(input_shape[0], <reduced length>, self.nb_filter)``.
     """
     stride = self.subsample[0]
     new_length = conv_output_length(input_shape[1], self.filter_length,
                                     self.border_mode, stride)
     return input_shape[0], new_length, self.nb_filter
Example #10
0
    def output_shape(self):
        """Return a 3-tuple ``(input_shape[0], nb_filter, out_cols)``.

        Only index 3 of ``self.input_shape`` is reduced, via
        ``conv_output_length`` with ``self.nb_col``, ``self.border_mode`` and
        the second entry of ``self.subsample``. No rows dimension appears in
        the result. NOTE(review): presumably the kernel spans the full row
        extent so that axis collapses; confirm against the layer's call().
        """
        shape_in = self.input_shape
        out_cols = conv_output_length(shape_in[3], self.nb_col,
                                      self.border_mode, self.subsample[1])
        return shape_in[0], self.nb_filter, out_cols