示例#1
0
    def build(self, input_shape):
        """Creates the kernel, optional bias and input spec of the layer.

        Sets `self.output_length`, `self.kernel_shape`, `self.kernel`,
        `self.bias` (None when `use_bias` is off), `self.input_spec` and
        `self.built`. Mode 2 additionally sets `self.kernel_mask`, and mode 3
        sets `self.kernel_idxs`.

        Args:
          input_shape: Rank-3 input shape. With `data_format ==
            'channels_first'` axis 1 is read as the channel dimension and
            axis 2 as the length; otherwise axis 2 is the channel dimension
            and axis 1 the length.

        Raises:
          ValueError: If the channel or length dimension is undefined, if the
            computed output length is <= 0, or if `self.implementation` is
            not 1, 2 or 3.
        """
        if self.data_format == 'channels_first':
            channel_axis, length_axis = 1, 2
        else:
            channel_axis, length_axis = 2, 1
        input_dim = input_shape[channel_axis]
        input_length = input_shape[length_axis]

        # Name the axis that is actually undefined (it depends on
        # data_format). Pass one formatted message so str(error) is readable
        # text rather than a tuple of arguments.
        if input_dim is None:
            raise ValueError(
                f'Axis {channel_axis} of input should be fully-defined. '
                f'Found shape: {input_shape}')
        # Every kernel layout below is sized from the input and/or output
        # length, so weights cannot be created for an unknown length.
        if input_length is None:
            raise ValueError(
                f'Axis {length_axis} of input should be fully-defined. '
                f'Found shape: {input_shape}')

        self.output_length = conv_utils.conv_output_length(
            input_length, self.kernel_size[0], self.padding, self.strides[0])

        if self.output_length <= 0:
            raise ValueError(
                f'One of the dimensions in the output is <= 0 '
                f'due to downsampling in {self.name}. Consider '
                f'increasing the input size. '
                f'Received input shape {input_shape} which would produce '
                f'output shape with a zero or negative value in a '
                f'dimension.')

        if self.implementation == 1:
            self.kernel_shape = (self.output_length,
                                 self.kernel_size[0] * input_dim, self.filters)
            kernel_weight_shape = self.kernel_shape

        elif self.implementation == 2:
            if self.data_format == 'channels_first':
                self.kernel_shape = (input_dim, input_length, self.filters,
                                     self.output_length)
            else:
                self.kernel_shape = (input_length, input_dim,
                                     self.output_length, self.filters)
            kernel_weight_shape = self.kernel_shape

            self.kernel_mask = locally_connected_utils.get_locallyconnected_mask(
                input_shape=(input_length, ),
                kernel_shape=self.kernel_size,
                strides=self.strides,
                padding=self.padding,
                data_format=self.data_format,
            )

        elif self.implementation == 3:
            # self.kernel_shape records the full flattened
            # (output_length * filters, input_length * input_dim) shape. The
            # trainable weight stores only one value per entry of kernel_idxs.
            self.kernel_shape = (self.output_length * self.filters,
                                 input_length * input_dim)

            self.kernel_idxs = sorted(
                conv_utils.conv_kernel_idxs(input_shape=(input_length, ),
                                            kernel_shape=self.kernel_size,
                                            strides=self.strides,
                                            padding=self.padding,
                                            filters_in=input_dim,
                                            filters_out=self.filters,
                                            data_format=self.data_format))
            kernel_weight_shape = (len(self.kernel_idxs), )

        else:
            # f-string instead of '%d' so a non-int mode still yields this
            # ValueError rather than a TypeError from the formatting itself.
            raise ValueError(
                f'Unrecognized implementation mode: {self.implementation}.')

        # All modes create the kernel the same way; only the variable's shape
        # differs between them.
        self.kernel = self.add_weight(shape=kernel_weight_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.output_length,
                                               self.filters),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        if self.data_format == 'channels_first':
            self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
        else:
            self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
        self.built = True
示例#2
0
    def build(self, input_shape):
        """Creates the kernel, optional bias and input spec of the layer.

        Sets `self.output_length`, `self.kernel_shape`, `self.kernel`,
        `self.bias` (None when `use_bias` is off), `self.input_spec` and
        `self.built`. Mode 2 additionally sets `self.kernel_mask`, and mode 3
        sets `self.kernel_idxs`.

        Args:
          input_shape: Rank-3 input shape. With `data_format ==
            'channels_first'` axis 1 is read as the channel dimension and
            axis 2 as the length; otherwise axis 2 is the channel dimension
            and axis 1 the length.

        Raises:
          ValueError: If the channel or length dimension is undefined, if the
            computed output length is <= 0, or if `self.implementation` is
            not 1, 2 or 3.
        """
        if self.data_format == 'channels_first':
            channel_axis, length_axis = 1, 2
        else:
            channel_axis, length_axis = 2, 1
        input_dim = input_shape[channel_axis]
        input_length = input_shape[length_axis]

        # Name the axis that is actually undefined (it depends on
        # data_format). Pass one formatted message so str(error) is readable
        # text rather than a tuple of arguments.
        if input_dim is None:
            raise ValueError(
                'Axis {} of input should be fully-defined. '
                'Found shape: {}'.format(channel_axis, input_shape))
        # Every kernel layout below is sized from the input and/or output
        # length, so weights cannot be created for an unknown length.
        if input_length is None:
            raise ValueError(
                'Axis {} of input should be fully-defined. '
                'Found shape: {}'.format(length_axis, input_shape))

        self.output_length = conv_utils.conv_output_length(
            input_length, self.kernel_size[0], self.padding, self.strides[0])

        # A non-positive output length would produce weight shapes with
        # zero/negative dimensions, so reject it up front with a clear message.
        if self.output_length <= 0:
            raise ValueError(
                'One of the dimensions in the output is <= 0 '
                'due to downsampling in {}. Consider '
                'increasing the input size. '
                'Received input shape {} which would produce '
                'output shape with a zero or negative value in a '
                'dimension.'.format(self.name, input_shape))

        if self.implementation == 1:
            self.kernel_shape = (self.output_length,
                                 self.kernel_size[0] * input_dim, self.filters)
            kernel_weight_shape = self.kernel_shape

        elif self.implementation == 2:
            if self.data_format == 'channels_first':
                self.kernel_shape = (input_dim, input_length, self.filters,
                                     self.output_length)
            else:
                self.kernel_shape = (input_length, input_dim,
                                     self.output_length, self.filters)
            kernel_weight_shape = self.kernel_shape

            self.kernel_mask = get_locallyconnected_mask(
                input_shape=(input_length, ),
                kernel_shape=self.kernel_size,
                strides=self.strides,
                padding=self.padding,
                data_format=self.data_format,
            )

        elif self.implementation == 3:
            # self.kernel_shape records the full flattened
            # (output_length * filters, input_length * input_dim) shape. The
            # trainable weight stores only one value per entry of kernel_idxs.
            self.kernel_shape = (self.output_length * self.filters,
                                 input_length * input_dim)

            self.kernel_idxs = sorted(
                conv_utils.conv_kernel_idxs(input_shape=(input_length, ),
                                            kernel_shape=self.kernel_size,
                                            strides=self.strides,
                                            padding=self.padding,
                                            filters_in=input_dim,
                                            filters_out=self.filters,
                                            data_format=self.data_format))
            kernel_weight_shape = (len(self.kernel_idxs), )

        else:
            # str.format instead of '%d' so a non-int mode still yields this
            # ValueError rather than a TypeError from the formatting itself.
            raise ValueError('Unrecognized implementation mode: {}.'.format(
                self.implementation))

        # All modes create the kernel the same way; only the variable's shape
        # differs between them.
        self.kernel = self.add_weight(shape=kernel_weight_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.output_length,
                                               self.filters),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        if self.data_format == 'channels_first':
            self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
        else:
            self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
        self.built = True
示例#3
0
    def build(self, input_shape):
        """Creates the kernel, optional bias and input spec of the layer.

        Sets `self.output_row`, `self.output_col`, `self.kernel_shape`,
        `self.kernel`, `self.bias` (None when `use_bias` is off),
        `self.input_spec` and `self.built`. Mode 2 additionally sets
        `self.kernel_mask`, and mode 3 sets `self.kernel_idxs`.

        Args:
          input_shape: Rank-4 input shape. With `data_format ==
            'channels_last'` axes 1 and 2 are read as rows and columns and
            axis 3 as channels; otherwise axis 1 is channels and axes 2 and 3
            are rows and columns.

        Raises:
          ValueError: If a spatial or channel dimension is undefined, if an
            output dimension would be <= 0, or if `self.implementation` is
            not 1, 2 or 3.
        """
        if self.data_format == 'channels_last':
            input_row, input_col = input_shape[1:-1]
            input_filter = input_shape[3]
        else:
            input_row, input_col = input_shape[2:]
            input_filter = input_shape[1]
        if input_row is None or input_col is None:
            raise ValueError('The spatial dimensions of the inputs to '
                             'a LocallyConnected2D layer '
                             'should be fully-defined, but layer received '
                             'the inputs shape ' + str(input_shape))
        # Every kernel layout below uses the channel count, so it must be
        # known before any weight can be created.
        if input_filter is None:
            raise ValueError('The channel dimension of the inputs to '
                             'a LocallyConnected2D layer '
                             'should be fully-defined, but layer received '
                             'the inputs shape ' + str(input_shape))
        output_row = conv_utils.conv_output_length(input_row,
                                                   self.kernel_size[0],
                                                   self.padding,
                                                   self.strides[0])
        output_col = conv_utils.conv_output_length(input_col,
                                                   self.kernel_size[1],
                                                   self.padding,
                                                   self.strides[1])
        self.output_row = output_row
        self.output_col = output_col

        # A non-positive output size would produce weight shapes with
        # zero/negative dimensions, so reject it up front with a clear message.
        if output_row <= 0 or output_col <= 0:
            raise ValueError(
                'One of the dimensions in the output is <= 0 '
                'due to downsampling in {}. Consider '
                'increasing the input size. '
                'Received input shape {} which would produce '
                'output shape with a zero or negative value in a '
                'dimension.'.format(self.name, input_shape))

        if self.implementation == 1:
            self.kernel_shape = (output_row * output_col, self.kernel_size[0] *
                                 self.kernel_size[1] * input_filter,
                                 self.filters)
            kernel_weight_shape = self.kernel_shape

        elif self.implementation == 2:
            if self.data_format == 'channels_first':
                self.kernel_shape = (input_filter, input_row, input_col,
                                     self.filters, self.output_row,
                                     self.output_col)
            else:
                self.kernel_shape = (input_row, input_col, input_filter,
                                     self.output_row, self.output_col,
                                     self.filters)
            kernel_weight_shape = self.kernel_shape

            self.kernel_mask = get_locallyconnected_mask(
                input_shape=(input_row, input_col),
                kernel_shape=self.kernel_size,
                strides=self.strides,
                padding=self.padding,
                data_format=self.data_format,
            )

        elif self.implementation == 3:
            # self.kernel_shape records the full flattened (outputs, inputs)
            # shape. The trainable weight stores only one value per entry of
            # kernel_idxs.
            self.kernel_shape = (self.output_row * self.output_col *
                                 self.filters,
                                 input_row * input_col * input_filter)

            self.kernel_idxs = sorted(
                conv_utils.conv_kernel_idxs(input_shape=(input_row, input_col),
                                            kernel_shape=self.kernel_size,
                                            strides=self.strides,
                                            padding=self.padding,
                                            filters_in=input_filter,
                                            filters_out=self.filters,
                                            data_format=self.data_format))
            kernel_weight_shape = (len(self.kernel_idxs), )

        else:
            # str.format instead of '%d' so a non-int mode still yields this
            # ValueError rather than a TypeError from the formatting itself.
            raise ValueError('Unrecognized implementation mode: {}.'.format(
                self.implementation))

        # All modes create the kernel the same way; only the variable's shape
        # differs between them.
        self.kernel = self.add_weight(shape=kernel_weight_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(output_row, output_col,
                                               self.filters),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        if self.data_format == 'channels_first':
            self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
        else:
            self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
        self.built = True
示例#4
0
    def build(self, input_shape):
        """Creates the kernel, optional bias and input spec of the layer.

        Sets `self.output_row`, `self.output_col`, `self.kernel_shape`,
        `self.kernel`, `self.bias` (None when `use_bias` is off),
        `self.input_spec` and `self.built`. Mode 2 additionally sets
        `self.kernel_mask`, and mode 3 sets `self.kernel_idxs`.

        Args:
          input_shape: Rank-4 input shape. With `data_format ==
            "channels_last"` axes 1 and 2 are read as rows and columns and
            axis 3 as channels; otherwise axis 1 is channels and axes 2 and 3
            are rows and columns.

        Raises:
          ValueError: If a spatial or channel dimension is undefined, if an
            output dimension would be <= 0, or if `self.implementation` is
            not 1, 2 or 3.
        """
        if self.data_format == "channels_last":
            input_row, input_col = input_shape[1:-1]
            input_filter = input_shape[3]
        else:
            input_row, input_col = input_shape[2:]
            input_filter = input_shape[1]
        if input_row is None or input_col is None:
            raise ValueError("The spatial dimensions of the inputs to "
                             "a LocallyConnected2D layer "
                             "should be fully-defined, but layer received "
                             "the inputs shape " + str(input_shape))
        # Every kernel layout below uses the channel count, so it must be
        # known before any weight can be created.
        if input_filter is None:
            raise ValueError("The channel dimension of the inputs to "
                             "a LocallyConnected2D layer "
                             "should be fully-defined, but layer received "
                             "the inputs shape " + str(input_shape))
        output_row = conv_utils.conv_output_length(input_row,
                                                   self.kernel_size[0],
                                                   self.padding,
                                                   self.strides[0])
        output_col = conv_utils.conv_output_length(input_col,
                                                   self.kernel_size[1],
                                                   self.padding,
                                                   self.strides[1])
        self.output_row = output_row
        self.output_col = output_col

        if self.output_row <= 0 or self.output_col <= 0:
            raise ValueError(
                f"One of the dimensions in the output is <= 0 "
                f"due to downsampling in {self.name}. Consider "
                f"increasing the input size. "
                f"Received input shape {input_shape} which would produce "
                f"output shape with a zero or negative value in a "
                f"dimension.")

        if self.implementation == 1:
            self.kernel_shape = (
                output_row * output_col,
                self.kernel_size[0] * self.kernel_size[1] * input_filter,
                self.filters,
            )
            kernel_weight_shape = self.kernel_shape

        elif self.implementation == 2:
            if self.data_format == "channels_first":
                self.kernel_shape = (
                    input_filter,
                    input_row,
                    input_col,
                    self.filters,
                    self.output_row,
                    self.output_col,
                )
            else:
                self.kernel_shape = (
                    input_row,
                    input_col,
                    input_filter,
                    self.output_row,
                    self.output_col,
                    self.filters,
                )
            kernel_weight_shape = self.kernel_shape

            self.kernel_mask = (
                locally_connected_utils.get_locallyconnected_mask(
                    input_shape=(input_row, input_col),
                    kernel_shape=self.kernel_size,
                    strides=self.strides,
                    padding=self.padding,
                    data_format=self.data_format,
                ))

        elif self.implementation == 3:
            # self.kernel_shape records the full flattened (outputs, inputs)
            # shape. The trainable weight stores only one value per entry of
            # kernel_idxs.
            self.kernel_shape = (
                self.output_row * self.output_col * self.filters,
                input_row * input_col * input_filter,
            )

            self.kernel_idxs = sorted(
                conv_utils.conv_kernel_idxs(
                    input_shape=(input_row, input_col),
                    kernel_shape=self.kernel_size,
                    strides=self.strides,
                    padding=self.padding,
                    filters_in=input_filter,
                    filters_out=self.filters,
                    data_format=self.data_format,
                ))
            kernel_weight_shape = (len(self.kernel_idxs), )

        else:
            # f-string instead of "%d" so a non-int mode still yields this
            # ValueError rather than a TypeError from the formatting itself.
            raise ValueError(
                f"Unrecognized implementation mode: {self.implementation}.")

        # All modes create the kernel the same way; only the variable's shape
        # differs between them.
        self.kernel = self.add_weight(
            shape=kernel_weight_shape,
            initializer=self.kernel_initializer,
            name="kernel",
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                shape=(output_row, output_col, self.filters),
                initializer=self.bias_initializer,
                name="bias",
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        else:
            self.bias = None
        if self.data_format == "channels_first":
            self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
        else:
            self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
        self.built = True