예제 #1
0
 def compute_output_shape(self, input_shape):
     """Compute the convolution output shape from the first input's shape.

     Only ``input_shape[0]`` (the shape of the first input) is used. Each
     spatial axis is resized with ``conv_utils.conv_output_length`` using this
     layer's kernel size, padding, strides and dilation rate.

     Returns:
         ``(batch, *spatial, filters)`` for ``'channels_last'``,
         ``(batch, filters, *spatial)`` for ``'channels_first'``, and
         ``None`` for any other ``data_format``.
     """
     primary = input_shape[0]
     fmt = self.data_format
     if fmt == 'channels_last':
         spatial_in = primary[1:-1]
     elif fmt == 'channels_first':
         spatial_in = primary[2:]
     else:
         return None

     spatial_out = tuple(
         conv_utils.conv_output_length(
             size,
             self.kernel_size[axis],
             padding=self.padding,
             stride=self.strides[axis],
             dilation=self.dilation_rate[axis])
         for axis, size in enumerate(spatial_in))

     if fmt == 'channels_last':
         return (primary[0],) + spatial_out + (self.filters,)
     return (primary[0], self.filters) + spatial_out
예제 #2
0
 def compute_output_shape(self, input_shape):
     """Return the ``TensorShape`` produced by this convolution.

     ``input_shape`` is normalised to a Python list first. Any
     ``data_format`` other than ``"channels_last"`` is treated as
     channels-first.

     Args:
         input_shape: Anything accepted by ``tensor_shape.TensorShape``.

     Returns:
         ``TensorShape([batch] + spatial + [filters])`` for channels-last,
         otherwise ``TensorShape([batch, filters] + spatial)``.
     """
     dims = tensor_shape.TensorShape(input_shape).as_list()
     is_channels_last = self.data_format == "channels_last"
     spatial_in = dims[1:-1] if is_channels_last else dims[2:]

     spatial_out = [
         conv_utils.conv_output_length(
             length,
             self.kernel_size[axis],
             padding=self.padding,
             stride=self.strides[axis],
             dilation=self.dilation_rate[axis],
         )
         for axis, length in enumerate(spatial_in)
     ]

     if is_channels_last:
         return tensor_shape.TensorShape(
             [dims[0]] + spatial_out + [self.filters])
     return tensor_shape.TensorShape([dims[0], self.filters] + spatial_out)
예제 #3
0
    def compute_output_shape(self, input_shape):
        """Compute the output shape of this 2D dilated pooling layer.

        Spatial sizes are always computed with ``'valid'`` padding, whatever
        ``self.padding`` is (see the TODO below). The previous version did this
        by temporarily overwriting ``self.padding`` and restoring it afterwards.
        That left the layer with the wrong padding if ``conv_output_length``
        raised. The override is now passed directly, so the layer's state is
        never touched.

        Args:
            input_shape: Anything accepted by ``tensor_shape.TensorShape``;
                indexed as a rank-4 shape.

        Returns:
            ``TensorShape`` of ``(batch, channels, rows, cols)`` for
            ``'channels_first'``, otherwise ``(batch, rows, cols, channels)``.
        """
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        if self.data_format == 'channels_first':
            rows = input_shape[2]
            cols = input_shape[3]
        else:
            rows = input_shape[1]
            cols = input_shape[2]

        # TODO: workaround! padding = 'same' shapes do not match, so the
        # output size is always computed as if padding were 'valid'.
        # Pass the override explicitly instead of mutating self.padding.
        padding = 'valid'

        rows = conv_utils.conv_output_length(rows,
                                             self.pool_size[0],
                                             padding=padding,
                                             stride=self.strides[0],
                                             dilation=self.dilation_rate[0])

        cols = conv_utils.conv_output_length(cols,
                                             self.pool_size[1],
                                             padding=padding,
                                             stride=self.strides[1],
                                             dilation=self.dilation_rate[1])

        if self.data_format == 'channels_first':
            output_shape = (input_shape[0], input_shape[1], rows, cols)
        else:
            output_shape = (input_shape[0], rows, cols, input_shape[3])

        return tensor_shape.TensorShape(output_shape)
예제 #4
0
    def compute_output_shape(self, input_shape):
        """Compute the output shape of this 2D dilated pooling layer.

        ``conv_output_length`` takes a single dilation value per axis.
        ``self.dilation_rate`` may be given either as an int or as a 2-tuple,
        so it is normalised with ``conv_utils.normalize_tuple`` and indexed
        per axis. Previously the raw value was passed straight through,
        which fails for tuple dilation rates. An int behaves exactly as
        before.

        Args:
            input_shape: Anything accepted by ``tensor_shape.TensorShape``;
                indexed as a rank-4 shape.

        Returns:
            ``TensorShape`` of ``(batch, channels, rows, cols)`` for
            ``'channels_first'``, otherwise ``(batch, rows, cols, channels)``.
        """
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        if self.data_format == 'channels_first':
            rows = input_shape[2]
            cols = input_shape[3]
        else:
            rows = input_shape[1]
            cols = input_shape[2]

        # One dilation value per spatial axis; accepts int or 2-tuple.
        dilation_rate = conv_utils.normalize_tuple(
            self.dilation_rate, 2, 'dilation_rate')

        rows = conv_utils.conv_output_length(rows,
                                             self.pool_size[0],
                                             padding=self.padding,
                                             stride=self.strides[0],
                                             dilation=dilation_rate[0])

        cols = conv_utils.conv_output_length(cols,
                                             self.pool_size[1],
                                             padding=self.padding,
                                             stride=self.strides[1],
                                             dilation=dilation_rate[1])

        if self.data_format == 'channels_first':
            output_shape = (input_shape[0], input_shape[1], rows, cols)
        else:
            output_shape = (input_shape[0], rows, cols, input_shape[3])

        return tensor_shape.TensorShape(output_shape)
예제 #5
0
    def compute_output_shape(self, input_shape):
        """Compute the output shape of this 3D convolution with depth multiplier.

        Each of the three spatial axes is resized independently with
        ``conv_utils.conv_output_length``. The channel axis becomes
        ``input_channels * depth_multiplier``.

        This fixes a copy-paste bug: the third spatial axis (``hgts``) used to
        be computed from the already-resized ``cols`` value instead of its
        own input length.

        Args:
            input_shape: Rank-5 shape in ``self.data_format`` layout.

        Returns:
            ``(batch, out_filters, rows, cols, hgts)`` for
            ``'channels_first'`` or ``(batch, rows, cols, hgts, out_filters)``
            for ``'channels_last'``.
        """
        if self.data_format == 'channels_first':
            rows = input_shape[2]
            cols = input_shape[3]
            hgts = input_shape[4]
            out_filters = input_shape[1] * self.depth_multiplier
        elif self.data_format == 'channels_last':
            rows = input_shape[1]
            cols = input_shape[2]
            hgts = input_shape[3]
            out_filters = input_shape[4] * self.depth_multiplier

        rows = conv_utils.conv_output_length(rows, self.kernel_size[0],
                                             self.padding,
                                             self.strides[0],
                                             self.dilation_rate[0])
        cols = conv_utils.conv_output_length(cols, self.kernel_size[1],
                                             self.padding,
                                             self.strides[1],
                                             self.dilation_rate[1])
        # Third spatial axis must be sized from its own input length.
        hgts = conv_utils.conv_output_length(hgts, self.kernel_size[2],
                                             self.padding,
                                             self.strides[2],
                                             self.dilation_rate[2])

        if self.data_format == 'channels_first':
            return (input_shape[0], out_filters, rows, cols, hgts)
        elif self.data_format == 'channels_last':
            return (input_shape[0], rows, cols, hgts, out_filters)
예제 #6
0
    def compute_output_shape(self, input_shape):
        """Compute the output shape of this 3D dilated pooling layer.

        ``conv_output_length`` takes one dilation value per axis.
        ``self.dilation_rate`` may be an int or a 3-tuple, so it is
        normalised with ``conv_utils.normalize_tuple`` and indexed per axis.
        Passing the raw value broke for tuple rates; an int behaves as
        before.

        Args:
            input_shape: Rank-5 shape in ``self.data_format`` layout.

        Returns:
            A plain tuple ``(batch, channels, time, rows, cols)`` for
            ``'channels_first'``, otherwise
            ``(batch, time, rows, cols, channels)``.
        """
        if self.data_format == 'channels_first':
            time = input_shape[2]
            rows = input_shape[3]
            cols = input_shape[4]
        else:
            time = input_shape[1]
            rows = input_shape[2]
            cols = input_shape[3]

        # One dilation value per spatial axis; accepts int or 3-tuple.
        dilation_rate = conv_utils.normalize_tuple(
            self.dilation_rate, 3, 'dilation_rate')

        time = conv_utils.conv_output_length(time,
                                             self.pool_size[0],
                                             padding=self.padding,
                                             stride=self.strides[0],
                                             dilation=dilation_rate[0])

        rows = conv_utils.conv_output_length(rows,
                                             self.pool_size[1],
                                             padding=self.padding,
                                             stride=self.strides[1],
                                             dilation=dilation_rate[1])

        cols = conv_utils.conv_output_length(cols,
                                             self.pool_size[2],
                                             padding=self.padding,
                                             stride=self.strides[2],
                                             dilation=dilation_rate[2])

        if self.data_format == 'channels_first':
            output_shape = (input_shape[0], input_shape[1], time, rows, cols)
        else:
            output_shape = (input_shape[0], time, rows, cols, input_shape[4])

        return output_shape
예제 #7
0
파일: pooling.py 프로젝트: xies/deepcell-tf
    def call(self, inputs):
        """Apply (optionally dilated) 2D max pooling.

        Pooling runs in NHWC layout. ``'channels_first'`` inputs are
        permuted to NHWC first and permuted back afterwards.

        For ``padding='valid'`` this is a single ``tf.nn.pool`` call. For
        ``padding='same'`` the input is explicitly padded in ``REFLECT`` mode
        and then pooled with ``'VALID'`` padding, so the border is filled
        with mirrored values.

        Fixes relative to the previous version:
          * the reflect-padded tensor is now the one that gets pooled (the
            padded result used to be computed and then discarded, so 'same'
            silently behaved like 'valid');
          * padding is sized with the normalised per-axis dilation rate
            rather than the raw ``self.dilation_rate``;
          * padding follows the TF 'SAME' rule (output length
            ``ceil(length / stride)``), with any odd extra pixel placed on
            the trailing side instead of being dropped.

        NOTE(review): 'same' pooling now really preserves size (up to
        stride). Any ``compute_output_shape`` that forces 'valid' sizing to
        match the old behaviour should be revisited.

        Args:
            inputs: Rank-4 tensor in ``self.data_format`` layout.

        Returns:
            The pooled tensor, in the same layout as ``inputs``.
        """
        if self.data_format == 'channels_first':
            inputs = K.permute_dimensions(inputs, pattern=[0, 2, 3, 1])

        dilation_rate = conv_utils.normalize_tuple(
            self.dilation_rate, 2, 'dilation_rate')

        if self.padding == 'valid':
            outputs = tf.nn.pool(inputs,
                                 window_shape=self.pool_size,
                                 pooling_type='MAX',
                                 padding=self.padding.upper(),
                                 dilation_rate=dilation_rate,
                                 strides=self.strides,
                                 data_format='NHWC')

        elif self.padding == 'same':
            input_shape = K.int_shape(inputs)

            # Explicit per-axis padding (rows = axis 1, cols = axis 2 in NHWC)
            # using the TF 'SAME' rule:
            #   out = ceil(in / stride)
            #   total = max((out - 1) * stride + extent - in, 0)
            pattern = [[0, 0]]
            for axis in range(2):
                size = input_shape[axis + 1]
                stride = self.strides[axis]
                # Window extent once the dilation gaps are included.
                extent = (self.pool_size[axis] - 1) * dilation_rate[axis] + 1
                out_size = -(-size // stride)  # integer ceil division
                total = max((out_size - 1) * stride + extent - size, 0)
                # Split evenly; an odd leftover pixel goes on the trailing side.
                pattern.append([total // 2, total - total // 2])
            pattern.append([0, 0])

            # Pad the image with mirrored border values...
            outputs = tf.pad(inputs, pattern, mode='REFLECT')

            # ...and pool the *padded* tensor without any further padding.
            outputs = tf.nn.pool(outputs,
                                 window_shape=self.pool_size,
                                 pooling_type='MAX',
                                 padding='VALID',
                                 dilation_rate=dilation_rate,
                                 strides=self.strides,
                                 data_format='NHWC')

        if self.data_format == 'channels_first':
            outputs = K.permute_dimensions(outputs, pattern=[0, 3, 1, 2])

        return outputs
    def compute_output_shape(self, input_shape):
        """Return one (identical) output shape per split produced by the layer.

        Spatial axes are resized with ``conv_utils.conv_output_length`` using
        this layer's kernel size, padding, strides and dilation rate. The
        channel axis becomes ``self.filters``.

        The number of returned shapes depends on ``self.splits``:
          * int: that many shapes;
          * list: the product over its items, where an int item contributes
            its value and a list item contributes its length (items of any
            other type are ignored).

        Returns:
            A list of shape tuples, or ``None`` when ``data_format`` is
            neither ``'channels_last'`` nor ``'channels_first'``.

        Raises:
            ValueError: if ``self.splits`` is neither an int nor a list.
        """
        if self.data_format == 'channels_last':
            spatial_in = input_shape[1:-1]
        elif self.data_format == 'channels_first':
            spatial_in = input_shape[2:]
        else:
            return None

        spatial_out = tuple(
            conv_utils.conv_output_length(
                size,
                self.kernel_size[axis],
                padding=self.padding,
                stride=self.strides[axis],
                dilation=self.dilation_rate[axis])
            for axis, size in enumerate(spatial_in))

        if isinstance(self.splits, int):
            count = self.splits
        elif isinstance(self.splits, list):
            count = 1
            for item in self.splits:
                if isinstance(item, int):
                    count *= item
                elif isinstance(item, list):
                    count *= len(item)
        else:
            raise ValueError('splits argument must be integer or list.')

        if self.data_format == 'channels_last':
            return [(input_shape[0],) + spatial_out + (self.filters,)
                    for _ in range(count)]
        return [(input_shape[0], self.filters) + spatial_out
                for _ in range(count)]
예제 #9
0
 def compute_output_shape(self, input_shape):
     """Return ``(batch, new_length, filters)`` for this 1D convolution.

     The length axis (index 1) is resized with ``self.padding`` and the first
     entries of ``kernel_size``, ``strides`` and ``dilation_rate``.
     """
     new_length = conv_output_length(
         input_shape[1], self.kernel_size[0], self.padding, self.strides[0],
         dilation=self.dilation_rate[0])
     batch = input_shape[0]
     return batch, new_length, self.filters
예제 #10
0
  def get_spatial_out(self,
                      spatial_in: List = None,
                      filter_shape: List = None,
                      strides: List = None,
                      padding: str = None,
                      dilations: List = None) -> List:
    """Compute the 2D spatial output size of a windowed operation.

    Every argument left as ``None`` falls back to an instance attribute:
    ``image_shape``, ``patch_shape``, ``strides``, ``padding`` and
    ``dilations`` respectively.

    Args:
      spatial_in: Input sizes; the first two entries are used.
      filter_shape: Window sizes; must have length 2 when given explicitly.
      strides: Per-axis strides; the first two entries are used.
      padding: Padding name; it is lower-cased before use.
      dilations: Per-axis dilation rates; the first two entries are used.

    Returns:
      A two-element list of output lengths.
    """
    spatial_in = self.image_shape if spatial_in is None else spatial_in
    if filter_shape is None:
      filter_shape = self.patch_shape
    else:
      assert len(filter_shape) == 2
    strides = self.strides if strides is None else strides
    padding = self.padding if padding is None else padding
    dilations = self.dilations if dilations is None else dilations

    return [
        conv_output_length(input_length=spatial_in[axis],
                           filter_size=filter_shape[axis],
                           stride=strides[axis],
                           padding=padding.lower(),
                           dilation=dilations[axis])
        for axis in (0, 1)
    ]
예제 #11
0
 def compute_output_shape(self, input_shape):
     """Return ``(batch, new_length, N_filt)`` for this 1D filter bank.

     The length axis (index 1) shrinks as it would under a 'valid',
     stride-1, undilated convolution with a window of ``self.Filt_dim``.
     """
     out_len = conv_utils.conv_output_length(
         input_shape[1], self.Filt_dim, padding="valid", stride=1, dilation=1)
     return (input_shape[0], out_len, self.N_filt)
예제 #12
0
 def compute_output_shape(self, input_shape):
     """Return the ``tf.TensorShape`` produced by this 2D pooling layer.

     Rows and cols are resized with ``pool_size``, ``self.padding`` and
     ``strides`` (no dilation). The batch and channel axes are unchanged.
     Any ``data_format`` other than ``'channels_first'`` is treated as
     channels-last.
     """
     dims = tf.TensorShape(input_shape).as_list()
     channels_first = self.data_format == 'channels_first'
     row_axis, col_axis = (2, 3) if channels_first else (1, 2)
     in_rows, in_cols = dims[row_axis], dims[col_axis]
     out_rows = conv_utils.conv_output_length(in_rows, self.pool_size[0],
                                              self.padding, self.strides[0])
     out_cols = conv_utils.conv_output_length(in_cols, self.pool_size[1],
                                              self.padding, self.strides[1])
     if channels_first:
         return tf.TensorShape([dims[0], dims[1], out_rows, out_cols])
     return tf.TensorShape([dims[0], out_rows, out_cols, dims[3]])
 def compute_output_shape(self, input_shape):
     """Return the ``TensorShape`` of 1D pooling: ``[batch, steps, channels]``.

     Only axis 1 changes; it is resized with ``pool_size[0]``,
     ``self.padding`` and ``strides[0]``.
     """
     dims = tensor_shape.TensorShape(input_shape).as_list()
     pooled = conv_utils.conv_output_length(dims[1], self.pool_size[0],
                                            self.padding, self.strides[0])
     return tensor_shape.TensorShape([dims[0], pooled, dims[2]])
예제 #14
0
 def build(self, input_shape):
   """Create the unshared kernel (and optional bias) for this layer.

   The kernel has shape ``(output_length, kernel_size[0] * input_dim,
   filters)``, i.e. one weight slice per output position. Both axis 1 and
   axis 2 of the input must therefore be statically known.

   Args:
     input_shape: Rank-3 shape; axis 1 is the length that is convolved and
       axis 2 is the input feature dimension.

   Raises:
     ValueError: if axis 2 or axis 1 of ``input_shape`` is ``None``.
   """
   input_dim = input_shape[2]
   if input_dim is None:
     # Build one message string; passing the shape as a second positional
     # argument made the exception render as a tuple.
     raise ValueError('Axis 2 of input should be fully-defined. '
                      'Found shape: ' + str(input_shape))
   if input_shape[1] is None:
     # The per-position kernel cannot be sized without a known length.
     raise ValueError('Axis 1 of input should be fully-defined. '
                      'Found shape: ' + str(input_shape))
   output_length = conv_utils.conv_output_length(
       input_shape[1], self.kernel_size[0], self.padding, self.strides[0])
   self.kernel_shape = (output_length, self.kernel_size[0] * input_dim,
                        self.filters)
   self.kernel = self.add_weight(
       shape=self.kernel_shape,
       initializer=self.kernel_initializer,
       name='kernel',
       regularizer=self.kernel_regularizer,
       constraint=self.kernel_constraint)
   if self.use_bias:
     self.bias = self.add_weight(
         shape=(output_length, self.filters),
         initializer=self.bias_initializer,
         name='bias',
         regularizer=self.bias_regularizer,
         constraint=self.bias_constraint)
   else:
     self.bias = None
   self.input_spec = InputSpec(ndim=3, axes={2: input_dim})
   self.built = True
예제 #15
0
    def compute_output_shape(self, input_shape):
        """Compute the output shape(s) of this 3D convolutional recurrent layer.

        ``input_shape`` (or its first element, if a list is given) has two
        leading axes, then three spatial axes and a channel axis ordered per
        ``self.cell.data_format``. The spatial axes are resized with the
        cell's kernel size, padding, strides and dilation rate. The channel
        axis becomes ``cell.filters``.

        Returns:
            The output shape, with axis 1 removed when ``return_sequences`` is
            false. When ``return_state`` is true, a list of that shape
            followed by two identical state shapes that start with the batch
            axis.
        """
        shape = input_shape[0] if isinstance(input_shape, list) else input_shape
        cell = self.cell

        if cell.data_format == 'channels_first':
            spatial = (shape[3], shape[4], shape[5])
        elif cell.data_format == 'channels_last':
            spatial = (shape[2], shape[3], shape[4])

        rows, cols, z = (
            conv_utils.conv_output_length(length,
                                          cell.kernel_size[axis],
                                          padding=cell.padding,
                                          stride=cell.strides[axis],
                                          dilation=cell.dilation_rate[axis])
            for axis, length in enumerate(spatial))

        if cell.data_format == 'channels_first':
            output_shape = shape[:2] + (cell.filters, rows, cols, z)
            state_shape = (shape[0], cell.filters, rows, cols, z)
        elif cell.data_format == 'channels_last':
            output_shape = shape[:2] + (rows, cols, z, cell.filters)
            state_shape = (shape[0], rows, cols, z, cell.filters)

        if not self.return_sequences:
            # Drop axis 1 (the sequence axis).
            output_shape = output_shape[:1] + output_shape[2:]

        if self.return_state:
            return [output_shape, state_shape, state_shape]
        return output_shape
예제 #16
0
    def compute_output_shape(self, input_shape):
        """Return the 2D convolution output shape as a plain tuple.

        Rows and cols are resized with ``kernel_size``, ``self.padding`` and
        ``strides`` (no dilation). The channel axis becomes ``self.filters``.
        """
        if self.data_format == 'channels_first':
            in_rows, in_cols = input_shape[2], input_shape[3]
        elif self.data_format == 'channels_last':
            in_rows, in_cols = input_shape[1], input_shape[2]

        out_rows = conv_utils.conv_output_length(
            in_rows, self.kernel_size[0], padding=self.padding,
            stride=self.strides[0])
        out_cols = conv_utils.conv_output_length(
            in_cols, self.kernel_size[1], padding=self.padding,
            stride=self.strides[1])

        if self.data_format == 'channels_first':
            return (input_shape[0], self.filters, out_rows, out_cols)
        elif self.data_format == 'channels_last':
            return (input_shape[0], out_rows, out_cols, self.filters)
예제 #17
0
  def compute_output_shape(self, input_shape):
    """Return the 2D convolution output shape as a plain tuple.

    Spatial axes start at index 2 for ``'channels_first'`` and at index 1
    for ``'channels_last'``. They are resized with ``kernel_size``,
    ``self.padding`` and ``strides``.
    """
    if self.data_format == 'channels_first':
      first_spatial = 2
    elif self.data_format == 'channels_last':
      first_spatial = 1
    spatial_in = (input_shape[first_spatial], input_shape[first_spatial + 1])

    rows, cols = [
        conv_utils.conv_output_length(length, self.kernel_size[k],
                                      self.padding, self.strides[k])
        for k, length in enumerate(spatial_in)
    ]

    if self.data_format == 'channels_first':
      return (input_shape[0], self.filters, rows, cols)
    elif self.data_format == 'channels_last':
      return (input_shape[0], rows, cols, self.filters)
예제 #18
0
 def compute_output_shape(self, input_shape):
   """Return the ``TensorShape`` produced by this 2D pooling layer.

   Rows and cols are resized with ``pool_size``, ``self.padding`` and
   ``strides``. The batch and channel axes pass through unchanged. Any
   ``data_format`` other than ``'channels_first'`` is treated as
   channels-last.
   """
   dims = tensor_shape.TensorShape(input_shape).as_list()
   channels_first = self.data_format == 'channels_first'
   start = 2 if channels_first else 1
   spatial_in = dims[start:start + 2]
   pooled = [
       conv_utils.conv_output_length(length, self.pool_size[k], self.padding,
                                     self.strides[k])
       for k, length in enumerate(spatial_in)
   ]
   if channels_first:
     return tensor_shape.TensorShape([dims[0], dims[1]] + pooled)
   return tensor_shape.TensorShape([dims[0]] + pooled + [dims[3]])
예제 #19
0
 def _spatial_output_shape(self, spatial_input_shape):
     """Map each spatial input length to its convolved output length.

     Axis ``i`` uses ``kernel_size[i]``, ``strides[i]`` and
     ``dilation_rate[i]`` together with ``self.padding``.

     Returns:
         A list with one output length per input axis.
     """
     result = []
     for axis, length in enumerate(spatial_input_shape):
         result.append(
             conv_utils.conv_output_length(length,
                                           self.kernel_size[axis],
                                           padding=self.padding,
                                           stride=self.strides[axis],
                                           dilation=self.dilation_rate[axis]))
     return result
예제 #20
0
 def test_conv_output_length(self):
     """Check conv_output_length for 'same', 'valid' and 'full' padding."""
     # (expected, input_length, filter_size, padding, stride, dilation)
     cases = [
         (4, 4, 2, 'same', 1, 1),
         (2, 4, 2, 'same', 2, 1),
         (3, 4, 2, 'valid', 1, 1),
         (2, 4, 2, 'valid', 2, 1),
         (5, 4, 2, 'full', 1, 1),
         (3, 4, 2, 'full', 2, 1),
         (2, 5, 2, 'valid', 2, 2),
     ]
     for expected, length, size, pad, stride, dilation in cases:
         self.assertEqual(
             expected,
             conv_utils.conv_output_length(length, size, pad, stride,
                                           dilation))
예제 #21
0
 def test_conv_output_length(self):
   """Check conv_output_length against known padding/stride/dilation cases."""
   expectations = (
       # ((input_length, filter_size, padding, stride, dilation), expected)
       ((4, 2, 'same', 1, 1), 4),
       ((4, 2, 'same', 2, 1), 2),
       ((4, 2, 'valid', 1, 1), 3),
       ((4, 2, 'valid', 2, 1), 2),
       ((4, 2, 'full', 1, 1), 5),
       ((4, 2, 'full', 2, 1), 3),
       ((5, 2, 'valid', 2, 2), 2),
   )
   for args, expected in expectations:
     self.assertEqual(expected, conv_utils.conv_output_length(*args))
예제 #22
0
 def build(self, input_shape):
     """Create the unshared 2D kernel and optional bias for this layer.

     The kernel has shape
     ``(output_row * output_col, kh * kw * input_filter, filters)``, i.e.
     one weight slice per output position. Both spatial dimensions must
     therefore be statically known.

     This also fixes the error message, whose concatenated pieces used to
     read "inputs to  a" (double space).

     Args:
         input_shape: Rank-4 shape in ``self.data_format`` layout.

     Raises:
         ValueError: if either spatial dimension is ``None``.
     """
     if self.data_format == 'channels_last':
         input_row, input_col = input_shape[1:-1]
         input_filter = input_shape[3]
     else:
         input_row, input_col = input_shape[2:]
         input_filter = input_shape[1]
     if input_row is None or input_col is None:
         raise ValueError('The spatial dimensions of the inputs to '
                          'a LocallyConnected2D layer '
                          'should be fully-defined, but layer received '
                          'the inputs shape ' + str(input_shape))
     output_row = conv_utils.conv_output_length(input_row,
                                                self.kernel_size[0],
                                                self.padding,
                                                self.strides[0])
     output_col = conv_utils.conv_output_length(input_col,
                                                self.kernel_size[1],
                                                self.padding,
                                                self.strides[1])
     self.output_row = output_row
     self.output_col = output_col
     self.kernel_shape = (output_row * output_col, self.kernel_size[0] *
                          self.kernel_size[1] * input_filter, self.filters)
     self.kernel = self.add_weight(shape=self.kernel_shape,
                                   initializer=self.kernel_initializer,
                                   name='kernel',
                                   regularizer=self.kernel_regularizer,
                                   constraint=self.kernel_constraint)
     if self.use_bias:
         self.bias = self.add_weight(shape=(output_row, output_col,
                                            self.filters),
                                     initializer=self.bias_initializer,
                                     name='bias',
                                     regularizer=self.bias_regularizer,
                                     constraint=self.bias_constraint)
     else:
         self.bias = None
     if self.data_format == 'channels_first':
         self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
     else:
         self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
     self.built = True
예제 #23
0
 def compute_output_shape(self, input_shape):
     """Return ``tf.TensorShape([batch, *spatial, filters])``.

     Every axis between the first and the last is resized with the scalar
     ``self.kernel`` and ``self.stride`` and with ``self.padding``.
     """
     dims = tf.TensorShape(input_shape).as_list()
     resized = [
         conv_utils.conv_output_length(length,
                                       self.kernel,
                                       padding=self.padding,
                                       stride=self.stride)
         for length in dims[1:-1]
     ]
     return tf.TensorShape([dims[0]] + resized + [self.filters])
예제 #24
0
    def get_output_spec(self, input_spec):
        """Update ``input_spec`` in place to describe the pooled output.

        ``input_spec['shape'][0]`` and ``[1]`` are resized with
        ``window[0]``/``window[1]`` and ``stride``/``dilation`` entries 1 and
        2. The new shape is ``(height, width)`` when ``self.squeeze`` is set
        and ``(height, width, self.size)`` otherwise. Any ``'min_value'`` and
        ``'max_value'`` keys are removed.

        Returns:
            The same (mutated) ``input_spec`` dict.
        """
        in_shape = input_spec['shape']
        height, width = (
            conv_output_length(
                input_length=in_shape[axis], filter_size=self.window[axis],
                padding=self.padding, stride=self.stride[axis + 1],
                dilation=self.dilation[axis + 1]
            )
            for axis in (0, 1)
        )

        if self.squeeze:
            input_spec['shape'] = (height, width)
        else:
            input_spec['shape'] = (height, width, self.size)

        for key in ('min_value', 'max_value'):
            input_spec.pop(key, None)

        return input_spec
예제 #25
0
 def get_new_space(space):
     """Return ``space`` mapped through a 'same'-padded convolution.

     Reads ``kernel_size``, ``strides`` and ``dilation_rate`` from ``self``
     in the enclosing scope. Padding is fixed to ``"same"``.

     Returns:
         A tuple with one output length per entry of ``space``.
     """
     return tuple(
         conv_utils.conv_output_length(
             length,
             self.kernel_size[axis],
             padding="same",
             stride=self.strides[axis],
             dilation=self.dilation_rate[axis])
         for axis, length in enumerate(space))
예제 #26
0
 def build(self, input_shape):
   """Create the unshared 2D kernel and optional bias for this layer.

   The kernel has shape
   ``(output_row * output_col, kh * kw * input_filter, filters)``, i.e. one
   weight slice per output position. Both spatial dimensions must therefore
   be statically known.

   This also fixes the error message, whose concatenated pieces used to read
   "inputs to  a" (double space).

   Args:
     input_shape: Rank-4 shape in ``self.data_format`` layout.

   Raises:
     ValueError: if either spatial dimension is ``None``.
   """
   if self.data_format == 'channels_last':
     input_row, input_col = input_shape[1:-1]
     input_filter = input_shape[3]
   else:
     input_row, input_col = input_shape[2:]
     input_filter = input_shape[1]
   if input_row is None or input_col is None:
     raise ValueError('The spatial dimensions of the inputs to '
                      'a LocallyConnected2D layer '
                      'should be fully-defined, but layer received '
                      'the inputs shape ' + str(input_shape))
   output_row = conv_utils.conv_output_length(input_row, self.kernel_size[0],
                                              self.padding, self.strides[0])
   output_col = conv_utils.conv_output_length(input_col, self.kernel_size[1],
                                              self.padding, self.strides[1])
   self.output_row = output_row
   self.output_col = output_col
   self.kernel_shape = (
       output_row * output_col,
       self.kernel_size[0] * self.kernel_size[1] * input_filter, self.filters)
   self.kernel = self.add_weight(
       shape=self.kernel_shape,
       initializer=self.kernel_initializer,
       name='kernel',
       regularizer=self.kernel_regularizer,
       constraint=self.kernel_constraint)
   if self.use_bias:
     self.bias = self.add_weight(
         shape=(output_row, output_col, self.filters),
         initializer=self.bias_initializer,
         name='bias',
         regularizer=self.bias_regularizer,
         constraint=self.bias_constraint)
   else:
     self.bias = None
   if self.data_format == 'channels_first':
     self.input_spec = InputSpec(ndim=4, axes={1: input_filter})
   else:
     self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
   self.built = True
예제 #27
0
    def output_spec(self):
        """Return the parent's output spec with pooled spatial dimensions.

        Shape entries 0 and 1 are resized with ``window[0]``/``window[1]``
        and ``stride``/``dilation`` entries 1 and 2. The shape becomes
        ``(height, width)`` when ``self.squeeze`` is set, otherwise
        ``(height, width, self.size)``. ``min_value`` and ``max_value`` are
        cleared.
        """
        spec = super().output_spec()
        in_shape = spec.shape

        height, width = (
            conv_output_length(
                input_length=in_shape[axis], filter_size=self.window[axis],
                padding=self.padding, stride=self.stride[axis + 1],
                dilation=self.dilation[axis + 1]
            )
            for axis in (0, 1)
        )

        spec.shape = (height, width) if self.squeeze else (height, width, self.size)
        spec.min_value = None
        spec.max_value = None
        return spec
예제 #28
0
  def compute_output_shape(self, input_shape):
    """Return the 1D convolution output shape as a plain tuple.

    The length axis is index 2 for ``'channels_first'`` and index 1
    otherwise. It is resized with ``kernel_size[0]``, ``self.padding`` and
    ``strides[0]``. Returns ``None`` for an unrecognised ``data_format``.
    """
    channels_first = self.data_format == 'channels_first'
    steps = input_shape[2] if channels_first else input_shape[1]
    length = conv_utils.conv_output_length(steps, self.kernel_size[0],
                                           self.padding, self.strides[0])
    if channels_first:
      return (input_shape[0], self.filters, length)
    if self.data_format == 'channels_last':
      return (input_shape[0], length, self.filters)
    return None
예제 #29
0
  def compute_output_shape(self, input_shape):
    """Compute ``(batch, filters, length)`` or ``(batch, length, filters)``.

    The layout follows ``self.data_format``. The length axis is resized with
    the first kernel size and stride and with ``self.padding``. Any
    ``data_format`` other than the two known values yields ``None``.
    """
    length_axis = 2 if self.data_format == 'channels_first' else 1
    out_len = conv_utils.conv_output_length(
        input_shape[length_axis], self.kernel_size[0], padding=self.padding,
        stride=self.strides[0])

    if self.data_format == 'channels_first':
      return (input_shape[0], self.filters, out_len)
    elif self.data_format == 'channels_last':
      return (input_shape[0], out_len, self.filters)
  def compute_output_shape(self, input_shape):
    """Compute the output shape(s) of this 2D convolutional recurrent layer.

    ``input_shape`` (or its first element, if a list is given) has two
    leading axes, then two spatial axes and a channel axis ordered per
    ``self.cell.data_format``. The spatial axes are resized with the cell's
    kernel size, padding, strides and dilation rate. The channel axis
    becomes ``cell.filters``.

    Returns:
      The output shape, with axis 1 removed when ``return_sequences`` is
      false. When ``return_state`` is true, a list of that shape followed by
      two identical state shapes that start with the batch axis.
    """
    shape = input_shape[0] if isinstance(input_shape, list) else input_shape
    cell = self.cell

    if cell.data_format == 'channels_first':
      spatial = (shape[3], shape[4])
    elif cell.data_format == 'channels_last':
      spatial = (shape[2], shape[3])

    rows, cols = (
        conv_utils.conv_output_length(length,
                                      cell.kernel_size[axis],
                                      padding=cell.padding,
                                      stride=cell.strides[axis],
                                      dilation=cell.dilation_rate[axis])
        for axis, length in enumerate(spatial))

    if cell.data_format == 'channels_first':
      output_shape = shape[:2] + (cell.filters, rows, cols)
      state_shape = (shape[0], cell.filters, rows, cols)
    elif cell.data_format == 'channels_last':
      output_shape = shape[:2] + (rows, cols, cell.filters)
      state_shape = (shape[0], rows, cols, cell.filters)

    if not self.return_sequences:
      # Drop axis 1 (the sequence axis).
      output_shape = output_shape[:1] + output_shape[2:]

    if self.return_state:
      return [output_shape, state_shape, state_shape]
    return output_shape
    def compute_output_shape(self, input_shape):
        """Return the output shape of the capsule convolution.

        Every axis between the batch axis and the last two axes is treated as
        spatial and shrunk by `conv_output_length` with the layer's `padding`
        and a fixed dilation of 1. The result ends with
        `(self.num_capsule, self.num_atoms)`.
        """
        # NOTE(review): `kernel_size` and `strides` are passed unindexed, so
        # they are presumably scalars shared by all spatial axes -- confirm.
        spatial_dims = tuple(
            conv_output_length(dim,
                               self.kernel_size,
                               padding=self.padding,
                               stride=self.strides,
                               dilation=1)
            for dim in input_shape[1:-2])
        return ((input_shape[0],) + spatial_dims
                + (self.num_capsule, self.num_atoms))
예제 #32
0
 def compute_output_shape(self, input_shape):
     """Return the output shape for a layer with two spatial axes.

     Axes 1 and 2 of `input_shape` are passed through
     `conv_utils.conv_output_length` with the matching entries of
     `kernel_size`, `strides` and `dilation_rate`, plus the layer's `padding`.
     The result is `(batch, *new_spatial, num_transformations, filters)`.
     """
     out_dims = []
     for axis, size in enumerate(input_shape[1:3]):
         out_dims.append(
             conv_utils.conv_output_length(size,
                                           self.kernel_size[axis],
                                           padding=self.padding,
                                           stride=self.strides[axis],
                                           dilation=self.dilation_rate[axis]))
     return ((input_shape[0],) + tuple(out_dims)
             + (self.num_transformations, self.filters))
예제 #33
0
 def compute_output_shape(self, input_shape):
   """Return the output `TensorShape` of a 3D pooling layer.

   The three spatial axes (2-4 for `channels_first`, 1-3 otherwise) are each
   reduced by `conv_utils.conv_output_length` using the matching entries of
   `pool_size` and `strides` and the layer's `padding`. The batch and channel
   axes are copied through unchanged.
   """
   dims = tensor_shape.TensorShape(input_shape).as_list()
   channels_first = self.data_format == 'channels_first'
   first_spatial = 2 if channels_first else 1
   pooled = [
       conv_utils.conv_output_length(dims[first_spatial + k],
                                     self.pool_size[k], self.padding,
                                     self.strides[k])
       for k in range(3)
   ]
   if channels_first:
     return tensor_shape.TensorShape([dims[0], dims[1]] + pooled)
   return tensor_shape.TensorShape([dims[0]] + pooled + [dims[4]])
 def compute_output_shape(self, input_shape):
     """Return two identical output shapes derived from the first input shape.

     Only `input_shape[0]` is inspected. Its axes between batch and channels
     go through `conv_utils.conv_output_length` with the per-axis
     `kernel_size`, `strides` and `dilation_rate`. Padding is always
     `'same'`; `self.padding` is not read. The channel axis becomes
     `self.filters`, and the resulting shape is returned twice in a list.
     """
     # NOTE(review): padding is hard-coded to 'same' -- confirm that the layer
     # really always pads this way.
     primary = input_shape[0]
     spatial = tuple(
         conv_utils.conv_output_length(size,
                                       self.kernel_size[idx],
                                       padding='same',
                                       stride=self.strides[idx],
                                       dilation=self.dilation_rate[idx])
         for idx, size in enumerate(primary[1:-1]))
     out_shape = (primary[0],) + spatial + (self.filters,)
     return [out_shape, out_shape]
예제 #35
0
 def compute_output_shape(self, input_shape):
     """Compute the 3D pooling output as a `TensorShape`.

     The spatial axes are 2, 3, 4 for `channels_first` and 1, 2, 3 otherwise.
     Each is reduced by `conv_utils.conv_output_length` with the matching
     `pool_size`/`strides` entry and the layer's `padding`. The batch and
     channel axes are preserved.
     """
     shape = tensor_shape.TensorShape(input_shape).as_list()
     if self.data_format == 'channels_first':
         spatial_axes = (2, 3, 4)
     else:
         spatial_axes = (1, 2, 3)
     new_dims = [
         conv_utils.conv_output_length(shape[axis], self.pool_size[i],
                                       self.padding, self.strides[i])
         for i, axis in enumerate(spatial_axes)
     ]
     if self.data_format == 'channels_first':
         return tensor_shape.TensorShape(shape[:2] + new_dims)
     return tensor_shape.TensorShape(shape[:1] + new_dims + [shape[4]])
예제 #36
0
 def compute_output_shape(self, input_shape):
     """Compute the output `TensorShape` of a 1D pooling layer.

     The steps axis (2 for `channels_first`, 1 otherwise) is reduced by
     `conv_utils.conv_output_length` using `pool_size[0]`, `padding` and
     `strides[0]`. The batch and feature axes are passed through.
     """
     shape = tensor_shape.TensorShape(input_shape).as_list()
     channels_first = self.data_format == 'channels_first'
     steps_axis, features_axis = (2, 1) if channels_first else (1, 2)
     features = shape[features_axis]
     pooled = conv_utils.conv_output_length(shape[steps_axis],
                                            self.pool_size[0],
                                            self.padding,
                                            self.strides[0])
     if channels_first:
         dims = [shape[0], features, pooled]
     else:
         dims = [shape[0], pooled, features]
     return tensor_shape.TensorShape(dims)
    def compute_output_shape(self, input_shape):
        """Return the output shape of the depthwise 1D convolution.

        Args:
            input_shape: Shape of the input. The length axis is 2 for
                `channels_first` and 1 for `channels_last`; the channel axis
                is the other of axes 1/2.

        Returns:
            `(batch, channels * depth_multiplier, new_length)` for
            `channels_first`, or `(batch, new_length,
            channels * depth_multiplier)` for `channels_last`.

        Raises:
            ValueError: if `self.data_format` is neither `'channels_first'`
                nor `'channels_last'`.
        """
        if self.data_format == 'channels_first':
            length_axis, channel_axis = 2, 1
        elif self.data_format == 'channels_last':
            length_axis, channel_axis = 1, 2
        else:
            # Without this branch an unknown format failed with an
            # UnboundLocalError on `length` rather than a clear message.
            raise ValueError('Unsupported data_format: %r; expected '
                             '"channels_first" or "channels_last".'
                             % (self.data_format,))

        length = input_shape[length_axis]
        out_filters = input_shape[channel_axis] * self.depth_multiplier

        # NOTE(review): `kernel_size` and `strides` are passed unindexed, so
        # they are presumably ints for this layer -- confirm.
        length = conv_utils.conv_output_length(length, self.kernel_size,
                                               self.padding, self.strides)
        if length_axis == 2:
            return (input_shape[0], out_filters, length)
        return (input_shape[0], length, out_filters)
예제 #38
0
 def compute_output_shape(self, input_shape):
   """Return the output `TensorShape` of a 1D pooling layer.

   For `channels_first` the input is read as `(batch, features, steps)`,
   otherwise as `(batch, steps, features)`. Only the steps axis changes: it is
   pooled with `conv_utils.conv_output_length` using `pool_size[0]`,
   `padding` and `strides[0]`.
   """
   shape = tensor_shape.TensorShape(input_shape).as_list()
   if self.data_format == 'channels_first':
     features, steps = shape[1], shape[2]
   else:
     steps, features = shape[1], shape[2]
   new_steps = conv_utils.conv_output_length(steps, self.pool_size[0],
                                             self.padding, self.strides[0])
   if self.data_format == 'channels_first':
     ordered = [shape[0], features, new_steps]
   else:
     ordered = [shape[0], new_steps, features]
   return tensor_shape.TensorShape(ordered)
예제 #39
0
 def compute_output_shape(self, input_shape):
   """Return the output `TensorShape` of a 1D pooling layer.

   Axis 1 is pooled with `conv_utils.conv_output_length` (`pool_size[0]`,
   `padding`, `strides[0]`). Axes 0 and 2 are kept as they are.
   `self.data_format` is not consulted.
   """
   dims = tensor_shape.TensorShape(input_shape).as_list()
   pooled_steps = conv_utils.conv_output_length(
       dims[1], self.pool_size[0], self.padding, self.strides[0])
   result_dims = [dims[0], pooled_steps, dims[2]]
   return tensor_shape.TensorShape(result_dims)
예제 #40
0
 def compute_output_shape(self, input_shape):
   """Return `(batch, new_length, filters)` for a 1D convolution.

   The length axis (index 1) is shrunk by `conv_utils.conv_output_length`
   using `kernel_size[0]`, `padding` and `strides[0]`.
   """
   in_length = input_shape[1]
   out_length = conv_utils.conv_output_length(
       in_length, self.kernel_size[0], self.padding, self.strides[0])
   batch_size = input_shape[0]
   return (batch_size, out_length, self.filters)
예제 #41
0
  def build(self, input_shape):
    """Create the weights of the 1D locally-connected layer.

    Args:
      input_shape: Rank-3 input shape (the layer's `InputSpec` uses
        `ndim=3`): `(batch, channels, steps)` for `channels_first`, otherwise
        `(batch, steps, channels)`.

    Raises:
      ValueError: if the channel dimension of `input_shape` is `None`, or if
        `self.implementation` is not 1 or 2.
    """
    if self.data_format == 'channels_first':
      channel_axis = 1
      input_dim, input_length = input_shape[1], input_shape[2]
    else:
      channel_axis = 2
      input_dim, input_length = input_shape[2], input_shape[1]

    if input_dim is None:
      # Format the shape into a single message (passing it as a separate
      # exception argument makes `str(err)` render as a tuple) and name the
      # channel axis that actually applies to this data format.
      raise ValueError('Axis %d of input should be fully-defined. '
                       'Found shape: %s' % (channel_axis, input_shape))
    self.output_length = conv_utils.conv_output_length(
        input_length, self.kernel_size[0], self.padding, self.strides[0])

    if self.implementation == 1:
      # One kernel slice per output position, each covering a flattened
      # `kernel_size[0] * input_dim` window.
      self.kernel_shape = (self.output_length, self.kernel_size[0] * input_dim,
                           self.filters)

      self.kernel = self.add_weight(
          shape=self.kernel_shape,
          initializer=self.kernel_initializer,
          name='kernel',
          regularizer=self.kernel_regularizer,
          constraint=self.kernel_constraint)

    elif self.implementation == 2:
      # Full input-by-output kernel; the layout depends on data_format.
      if self.data_format == 'channels_first':
        self.kernel_shape = (input_dim, input_length,
                             self.filters, self.output_length)
      else:
        self.kernel_shape = (input_length, input_dim,
                             self.output_length, self.filters)

      self.kernel = self.add_weight(shape=self.kernel_shape,
                                    initializer=self.kernel_initializer,
                                    name='kernel',
                                    regularizer=self.kernel_regularizer,
                                    constraint=self.kernel_constraint)

      # NOTE(review): presumably applied to `self.kernel` in `call` so that
      # only local connections are active -- confirm.
      self.kernel_mask = get_locallyconnected_mask(
          input_shape=(input_length,),
          kernel_shape=self.kernel_size,
          strides=self.strides,
          padding=self.padding,
          data_format=self.data_format,
          dtype=self.kernel.dtype
      )

    else:
      # %r (not %d) so non-numeric values still produce this ValueError
      # instead of a formatting TypeError.
      raise ValueError('Unrecognized implementation mode: %r.'
                       % (self.implementation,))

    if self.use_bias:
      # One bias vector per output position.
      self.bias = self.add_weight(
          shape=(self.output_length, self.filters),
          initializer=self.bias_initializer,
          name='bias',
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint)
    else:
      self.bias = None

    if self.data_format == 'channels_first':
      self.input_spec = InputSpec(ndim=3, axes={1: input_dim})
    else:
      self.input_spec = InputSpec(ndim=3, axes={-1: input_dim})
    self.built = True