Exemplo n.º 1
0
    def build(self, input_shape):
        """Create the kernel (and optional bias) weights from a static shape.

        The channel and length dimensions are read from a rank-3
        ``input_shape``:

        * ``'channels_first'``: channels at axis 1, length at axis 2.
        * otherwise: channels at axis 2, length at axis 1.

        Axis 0 is not read here (presumably the batch axis).

        ``self.output_length`` is computed with
        ``conv_utils.conv_output_length``. The kernel is then allocated in
        the layout selected by ``self.implementation``:

        * ``1``: kernel of shape
          ``(output_length, kernel_size[0] * input_dim, filters)``.
        * ``2``: kernel spanning every input/output position (axis order
          depends on ``data_format``), plus ``self.kernel_mask`` built by
          ``get_locallyconnected_mask``.
        * ``3``: flat kernel with one entry per element of
          ``self.kernel_idxs`` (the sorted result of
          ``conv_utils.conv_kernel_idxs``). ``self.kernel_shape`` records the
          dense ``(output_length * filters, input_length * input_dim)`` shape,
          presumably the matrix those indices refer into.

        When ``self.use_bias`` is set, a bias of shape
        ``(output_length, filters)`` is also created.

        Args:
            input_shape: Static shape of the layer input, indexable at axes 1
                and 2.

        Raises:
            ValueError: If the channel or length dimension is ``None``, if the
                computed output length is not positive, or if
                ``self.implementation`` is not 1, 2 or 3.
        """
        if self.data_format == 'channels_first':
            channel_axis, length_axis = 1, 2
        else:
            channel_axis, length_axis = 2, 1
        input_dim = input_shape[channel_axis]
        input_length = input_shape[length_axis]

        # Every implementation below sizes its weights from both the channel
        # and the length dimension, so both must be statically known.
        # Otherwise the failure surfaces later as an obscure weight-shape error.
        if input_dim is None:
            raise ValueError(
                'Axis %d of input should be fully-defined. '
                'Found shape: %s' % (channel_axis, input_shape))
        if input_length is None:
            raise ValueError(
                'Axis %d of input (the length dimension) should be '
                'fully-defined. Found shape: %s' % (length_axis, input_shape))

        self.output_length = conv_utils.conv_output_length(
            input_length, self.kernel_size[0], self.padding, self.strides[0])
        # conv_output_length can presumably return a non-positive length when
        # the kernel does not fit the input (TODO confirm against conv_utils).
        # Reject it here rather than creating weights with a zero or negative
        # dimension.
        if self.output_length <= 0:
            raise ValueError(
                'The computed output length is %s, but it must be positive. '
                'Consider increasing the input length, or choosing a smaller '
                'kernel_size or a different padding. Found shape: %s'
                % (self.output_length, input_shape))

        if self.implementation == 1:
            self.kernel_shape = (self.output_length,
                                 self.kernel_size[0] * input_dim, self.filters)

            self.kernel = self.add_weight(shape=self.kernel_shape,
                                          initializer=self.kernel_initializer,
                                          name='kernel',
                                          regularizer=self.kernel_regularizer,
                                          constraint=self.kernel_constraint)

        elif self.implementation == 2:
            if self.data_format == 'channels_first':
                self.kernel_shape = (input_dim, input_length, self.filters,
                                     self.output_length)
            else:
                self.kernel_shape = (input_length, input_dim,
                                     self.output_length, self.filters)

            self.kernel = self.add_weight(shape=self.kernel_shape,
                                          initializer=self.kernel_initializer,
                                          name='kernel',
                                          regularizer=self.kernel_regularizer,
                                          constraint=self.kernel_constraint)

            # Presumably marks which (input, output) position pairs fall
            # inside a kernel window; the full kernel above covers all pairs.
            self.kernel_mask = get_locallyconnected_mask(
                input_shape=(input_length, ),
                kernel_shape=self.kernel_size,
                strides=self.strides,
                padding=self.padding,
                data_format=self.data_format,
            )

        elif self.implementation == 3:
            self.kernel_shape = (self.output_length * self.filters,
                                 input_length * input_dim)

            # Sorted so the order of the flat weights below is deterministic.
            self.kernel_idxs = sorted(
                conv_utils.conv_kernel_idxs(input_shape=(input_length, ),
                                            kernel_shape=self.kernel_size,
                                            strides=self.strides,
                                            padding=self.padding,
                                            filters_in=input_dim,
                                            filters_out=self.filters,
                                            data_format=self.data_format))

            # One trainable scalar per index in kernel_idxs.
            self.kernel = self.add_weight(shape=(len(self.kernel_idxs), ),
                                          initializer=self.kernel_initializer,
                                          name='kernel',
                                          regularizer=self.kernel_regularizer,
                                          constraint=self.kernel_constraint)

        else:
            # %s (not %d) so a non-numeric value still yields this ValueError
            # instead of a TypeError from the formatting itself.
            raise ValueError('Unrecognized implementation mode: %s.' %
                             (self.implementation, ))

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.output_length,
                                               self.filters),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        # Pin the rank and the channel count so later calls with a
        # different channel dimension are rejected.
        if self.data_format == 'channels_first':
            self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
        else:
            self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
        self.built = True
Exemplo n.º 2
0
    def build(self, input_shape):
        """Create the kernel (and optional bias) weights from a static shape.

        The spatial dimensions ``(input_row, input_col)`` and the channel
        count ``input_filter`` are read from a rank-4 ``input_shape``:

        * ``'channels_last'``: spatial at axes 1-2, channels at axis 3.
        * otherwise: channels at axis 1, spatial at axes 2-3.

        Output rows and columns are computed with
        ``conv_utils.conv_output_length`` and stored on ``self.output_row`` /
        ``self.output_col``. The kernel is then allocated in the layout
        selected by ``self.implementation``:

        * ``1``: kernel of shape ``(output_row * output_col,
          kernel_size[0] * kernel_size[1] * input_filter, filters)``.
        * ``2``: kernel spanning every input/output position (axis order
          depends on ``data_format``), plus ``self.kernel_mask`` built by
          ``get_locallyconnected_mask``.
        * ``3``: flat kernel with one entry per element of
          ``self.kernel_idxs`` (the sorted result of
          ``conv_utils.conv_kernel_idxs``). ``self.kernel_shape`` records the
          dense ``(output_row * output_col * filters,
          input_row * input_col * input_filter)`` shape, presumably the
          matrix those indices refer into.

        When ``self.use_bias`` is set, a bias of shape
        ``(output_row, output_col, filters)`` is also created.

        Args:
            input_shape: Static rank-4 shape of the layer input.

        Raises:
            ValueError: If a spatial or the channel dimension is ``None``, if
                a computed output dimension is not positive, or if
                ``self.implementation`` is not 1, 2 or 3.
        """
        if self.data_format == 'channels_last':
            input_row, input_col = input_shape[1:-1]
            input_filter = input_shape[3]
        else:
            input_row, input_col = input_shape[2:]
            input_filter = input_shape[1]
        if input_row is None or input_col is None:
            raise ValueError('The spatial dimensions of the inputs to '
                             'a LocallyConnected2D layer '
                             'should be fully-defined, but layer received '
                             'the inputs shape ' + str(input_shape))
        # Every kernel layout below embeds the channel count, so it must be
        # known too. Otherwise the failure surfaces later as an obscure
        # TypeError or weight-shape error.
        if input_filter is None:
            raise ValueError('The channel dimension of the inputs to '
                             'a LocallyConnected2D layer '
                             'should be fully-defined, but layer received '
                             'the inputs shape ' + str(input_shape))
        output_row = conv_utils.conv_output_length(input_row,
                                                   self.kernel_size[0],
                                                   self.padding,
                                                   self.strides[0])
        output_col = conv_utils.conv_output_length(input_col,
                                                   self.kernel_size[1],
                                                   self.padding,
                                                   self.strides[1])
        # conv_output_length can presumably return a non-positive size when
        # the kernel does not fit the input (TODO confirm against conv_utils).
        # Reject it rather than creating weights with a zero or negative
        # dimension.
        if output_row <= 0 or output_col <= 0:
            raise ValueError('The computed output spatial dimensions of a '
                             'LocallyConnected2D layer are (%s, %s), but both '
                             'must be positive. Consider increasing the input '
                             'size, or choosing a smaller kernel_size or a '
                             'different padding. Received input shape %s'
                             % (output_row, output_col, input_shape))
        self.output_row = output_row
        self.output_col = output_col

        if self.implementation == 1:
            self.kernel_shape = (output_row * output_col, self.kernel_size[0] *
                                 self.kernel_size[1] * input_filter,
                                 self.filters)

            self.kernel = self.add_weight(shape=self.kernel_shape,
                                          initializer=self.kernel_initializer,
                                          name='kernel',
                                          regularizer=self.kernel_regularizer,
                                          constraint=self.kernel_constraint)

        elif self.implementation == 2:
            if self.data_format == 'channels_first':
                self.kernel_shape = (input_filter, input_row, input_col,
                                     self.filters, self.output_row,
                                     self.output_col)
            else:
                self.kernel_shape = (input_row, input_col, input_filter,
                                     self.output_row, self.output_col,
                                     self.filters)

            self.kernel = self.add_weight(shape=self.kernel_shape,
                                          initializer=self.kernel_initializer,
                                          name='kernel',
                                          regularizer=self.kernel_regularizer,
                                          constraint=self.kernel_constraint)

            # Presumably marks which (input, output) position pairs fall
            # inside a kernel window; the full kernel above covers all pairs.
            self.kernel_mask = get_locallyconnected_mask(
                input_shape=(input_row, input_col),
                kernel_shape=self.kernel_size,
                strides=self.strides,
                padding=self.padding,
                data_format=self.data_format,
            )

        elif self.implementation == 3:
            self.kernel_shape = (self.output_row * self.output_col *
                                 self.filters,
                                 input_row * input_col * input_filter)

            # Sorted so the order of the flat weights below is deterministic.
            self.kernel_idxs = sorted(
                conv_utils.conv_kernel_idxs(input_shape=(input_row, input_col),
                                            kernel_shape=self.kernel_size,
                                            strides=self.strides,
                                            padding=self.padding,
                                            filters_in=input_filter,
                                            filters_out=self.filters,
                                            data_format=self.data_format))

            # One trainable scalar per index in kernel_idxs.
            self.kernel = self.add_weight(shape=(len(self.kernel_idxs), ),
                                          initializer=self.kernel_initializer,
                                          name='kernel',
                                          regularizer=self.kernel_regularizer,
                                          constraint=self.kernel_constraint)

        else:
            # %s (not %d) so a non-numeric value still yields this ValueError
            # instead of a TypeError from the formatting itself.
            raise ValueError('Unrecognized implementation mode: %s.' %
                             (self.implementation, ))

        if self.use_bias:
            self.bias = self.add_weight(shape=(output_row, output_col,
                                               self.filters),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Pin the rank and the channel count so later calls with a
        # different channel dimension are rejected.
        if self.data_format == 'channels_first':
            self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
        else:
            self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
        self.built = True