  def call(self, inputs):
    if self._reshape_required:
     reshaped_inputs = []
     input_ndims = list(map(K.ndim, inputs))
     if None not in input_ndims:
       # If ranks of all inputs are available,
       # we simply expand each of them at axis=1
       # until all of them have the same rank.
       max_ndim = max(input_ndims)
       for x in inputs:
         x_ndim = K.ndim(x)
         for _ in range(max_ndim - x_ndim):
           x = K.expand_dims(x, 1)
         reshaped_inputs.append(x)
       return self._merge_function(reshaped_inputs)
     else:
       # Transpose all inputs so that batch size is the last dimension.
       # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
       transposed = False
       for x in inputs:
         x_ndim = K.ndim(x)
         if x_ndim is None:
           x_shape = K.shape(x)
           batch_size = x_shape[0]
           new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)])
           x_transposed = K.reshape(x,
                                    K.stack([batch_size,
                                             K.prod(x_shape[1:])]))
           x_transposed = K.permute_dimensions(x_transposed, (1, 0))
           x_transposed = K.reshape(x_transposed, new_shape)
           reshaped_inputs.append(x_transposed)
           transposed = True
         elif x_ndim > 1:
           dims = list(range(1, x_ndim)) + [0]
           reshaped_inputs.append(K.permute_dimensions(x, dims))
           transposed = True
         else:
           # We don't transpose inputs if they are 1D vectors or scalars.
           reshaped_inputs.append(x)
       y = self._merge_function(reshaped_inputs)
       y_ndim = K.ndim(y)
       if transposed:
         # If inputs have been transposed, we have to transpose the output too.
         if y_ndim is None:
           y_shape = K.shape(y)
           y_ndim = K.shape(y_shape)[0]
           batch_size = y_shape[y_ndim - 1]
           new_shape = K.concatenate(
               [K.expand_dims(batch_size), y_shape[:y_ndim - 1]])
           y = K.reshape(y, (-1, batch_size))
           y = K.permute_dimensions(y, (1, 0))
           y = K.reshape(y, new_shape)
         elif y_ndim > 1:
           dims = [y_ndim - 1] + list(range(y_ndim - 1))
           y = K.permute_dimensions(y, dims)
       return y
   else:
     return self._merge_function(inputs)
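When every input rank is known, the first branch just inserts size-1 axes at position 1 until the ranks agree and lets the merge function broadcast. A minimal sketch of that alignment (my own illustration, not part of the layer, assuming the TensorFlow backend):

import numpy as np
import tensorflow.keras.backend as K

a = K.constant(np.ones((2, 3)))      # ndim 2
b = K.constant(np.ones((2, 4, 3)))   # ndim 3
while K.ndim(a) < K.ndim(b):
  a = K.expand_dims(a, 1)            # -> shape (2, 1, 3)
# With equal ranks, an Add-style merge can broadcast:
print(K.int_shape(a + b))            # (2, 4, 3)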
def f_lrn(X):
    # Local response normalization; n, k, alpha and beta are the LRN
    # hyper-parameters captured from the enclosing scope.
    # X has shape (samples, rows, cols, channels).
    _, _, _, channels = X.get_shape()
    half_n = n // 2
    X_square = K.square(X)

    # Zero-pad the channel axis with half_n channels on each side.
    # K.spatial_2d_padding only pads the two middle axes, so the channel
    # axis is temporarily swapped into a "spatial" position and back.
    X_square_padded = K.spatial_2d_padding(
        K.permute_dimensions(X_square, (1, 0, 3, 2)),
        ((0, 0), (half_n, half_n)))
    X_square_padded = K.permute_dimensions(X_square_padded, (1, 0, 3, 2))

    # The sum runs over n "adjacent" kernel maps at the same spatial
    # position (renamed from `sum` to avoid shadowing the built-in).
    squared_sum = 0
    for i in range(n):
        squared_sum += X_square_padded[:, :, :, i:(i + int(channels))]
    scale = (k + alpha * squared_sum) ** beta
    return X / scale
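For odd n, this hand-rolled normalization should agree with TensorFlow's built-in LRN op, which divides by (bias + alpha * sqr_sum) ** beta over depth_radius channels on each side of every position. A rough cross-check sketch (the hyper-parameter values below are placeholders, not taken from the original code):

import numpy as np
import tensorflow as tf

n, k, alpha, beta = 5, 2.0, 1e-4, 0.75  # placeholder LRN settings
x = np.random.rand(1, 4, 4, 8).astype('float32')
# tf.nn.local_response_normalization sums squares over a window of
# 2 * depth_radius + 1 channels, i.e. depth_radius = n // 2 for odd n.
ref = tf.nn.local_response_normalization(
    x, depth_radius=n // 2, bias=k, alpha=alpha, beta=beta)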
Example #3
  def call(self, inputs):
    stride = self.strides[0]
    output_length, feature_dim, filters = self.kernel_shape

    xs = []
    for i in range(output_length):
      slice_length = slice(i * stride, i * stride + self.kernel_size[0])
      xs.append(K.reshape(inputs[:, slice_length, :], (1, -1, feature_dim)))
    x_aggregate = K.concatenate(xs, axis=0)
    # Shape: `(output_length, batch_size, filters)`.
    output = K.batch_dot(x_aggregate, self.kernel)
    output = K.permute_dimensions(output, (1, 0, 2))

    if self.use_bias:
      output += K.reshape(self.bias, (1, output_length, filters))
    if self.activation is not None:
      output = self.activation(output)
    return output
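This call matches the forward pass of Keras's LocallyConnected1D: each output position has its own unshared kernel slice, so the per-position input patches are stacked along a new leading axis and contracted in a single K.batch_dot. A usage sketch (assuming a TensorFlow/Keras version that still ships the layer; it was removed in Keras 3):

from tensorflow.keras import layers, models

model = models.Sequential([
    # 10 timesteps, 6 features; kernel_size=3, stride 1.
    layers.LocallyConnected1D(4, 3, input_shape=(10, 6)),
])
# output_length = (10 - 3) // 1 + 1 = 8, so the output shape is (None, 8, 4).
model.summary()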
Example #4
  def call(self, inputs):
    return K.permute_dimensions(inputs, (0,) + self.dims)
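That one-liner is the whole forward pass of keras.layers.Permute: self.dims is the user-supplied permutation of the non-batch axes (1-indexed), and the batch axis always stays in front. Usage:

import numpy as np
from tensorflow.keras import layers

x = np.zeros((2, 10, 64), dtype='float32')
y = layers.Permute((2, 1))(x)  # swap the two non-batch axes
print(y.shape)                 # (2, 64, 10)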
Example #5
  def call(self, inputs):
    stride_row, stride_col = self.strides
    _, feature_dim, filters = self.kernel_shape

    if self.data_format == 'channels_first':
      if K.backend() == 'theano':
        output = []
        for i in range(self.output_row):
          for j in range(self.output_col):
            slice_row = slice(i * stride_row,
                              i * stride_row + self.kernel_size[0])
            slice_col = slice(j * stride_col,
                              j * stride_col + self.kernel_size[1])
            x_flatten = K.reshape(inputs[:, :, slice_row, slice_col],
                                  (1, -1, feature_dim))
            output.append(
                K.dot(x_flatten, self.kernel[i * self.output_col + j, :, :]))
        output = K.concatenate(output, axis=0)
      else:
        xs = []
        for i in range(self.output_row):
          for j in range(self.output_col):
            slice_row = slice(i * stride_row,
                              i * stride_row + self.kernel_size[0])
            slice_col = slice(j * stride_col,
                              j * stride_col + self.kernel_size[1])
            xs.append(
                K.reshape(inputs[:, :, slice_row, slice_col], (1, -1,
                                                               feature_dim)))
        x_aggregate = K.concatenate(xs, axis=0)
        output = K.batch_dot(x_aggregate, self.kernel)
      output = K.reshape(output, (self.output_row, self.output_col, -1,
                                  filters))
      output = K.permute_dimensions(output, (2, 3, 0, 1))

    elif self.data_format == 'channels_last':
      xs = []
      for i in range(self.output_row):
        for j in range(self.output_col):
          slice_row = slice(i * stride_row,
                            i * stride_row + self.kernel_size[0])
          slice_col = slice(j * stride_col,
                            j * stride_col + self.kernel_size[1])
          xs.append(
              K.reshape(inputs[:, slice_row, slice_col, :], (1, -1, feature_dim
                                                            )))
      x_aggregate = K.concatenate(xs, axis=0)
      output = K.batch_dot(x_aggregate, self.kernel)
      output = K.reshape(output, (self.output_row, self.output_col, -1,
                                  filters))
      output = K.permute_dimensions(output, (2, 0, 1, 3))

    if self.use_bias:
      if self.data_format == 'channels_first':
        output += K.reshape(self.bias, (1, filters, self.output_row,
                                        self.output_col))
      elif self.data_format == 'channels_last':
        output += K.reshape(self.bias, (1, self.output_row, self.output_col,
                                        filters))
    output = self.activation(output)
    return output
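This is the 2-D version of the same patch-stacking trick, with a per-position K.dot fallback on the Theano backend. A usage sketch (again assuming a pre-Keras-3 version that still includes LocallyConnected2D):

from tensorflow.keras import layers, models

model = models.Sequential([
    # 28x28 single-channel input, 3x3 unshared kernels, stride 1.
    layers.LocallyConnected2D(8, (3, 3), input_shape=(28, 28, 1)),
])
# output_row = output_col = (28 - 3) // 1 + 1 = 26 -> shape (None, 26, 26, 8).
model.summary()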