Example #1
def conv_block_2d(inputs,
                  filters=128,
                  activation='relu',
                  conv_type='standard',
                  kernel_size=1,
                  strides=1,
                  dilation_rate=1,
                  l2_scale=0,
                  dropout=0,
                  pool_size=1,
                  batch_norm=False,
                  bn_momentum=0.99,
                  bn_gamma='ones',
                  symmetric=False):
    """Construct a single 2D convolution block.   """

    # flow through variable current
    current = inputs

    # activation
    current = layers.activate(current, activation)

    # choose convolution type
    if conv_type == 'separable':
        conv_layer = tf.keras.layers.SeparableConv2D
    else:
        conv_layer = tf.keras.layers.Conv2D

    # convolution
    current = conv_layer(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
        use_bias=False,
        dilation_rate=dilation_rate,
        kernel_initializer='he_normal',
        kernel_regularizer=tf.keras.regularizers.l2(l2_scale))(current)

    # batch norm
    if batch_norm:
        current = tf.keras.layers.BatchNormalization(
            momentum=bn_momentum, gamma_initializer=bn_gamma,
            fused=True)(current)

    # dropout
    if dropout > 0:
        current = tf.keras.layers.Dropout(rate=dropout)(current)

    # pool
    if pool_size > 1:
        current = tf.keras.layers.MaxPool2D(pool_size=pool_size,
                                            padding='same')(current)

    # symmetric
    if symmetric:
        current = layers.Symmetrize2D()(current)

    return current
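
Usage sketch for conv_block_2d (illustrative only, not part of the original module): it assumes TensorFlow 2.x is importable as tf and that the project-local `layers` module referenced above (providing activate and Symmetrize2D) is on the path; shapes and hyperparameters are arbitrary.

import tensorflow as tf

# hypothetical [length, length, channels] 2D feature map
inputs_2d = tf.keras.Input(shape=(64, 64, 48))
outputs_2d = conv_block_2d(inputs_2d, filters=64, kernel_size=3,
                           batch_norm=True, dropout=0.1, symmetric=True)
model_2d = tf.keras.Model(inputs_2d, outputs_2d)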
Example #2
    def build_model(self, save_reprs=False):
        ###################################################
        # inputs
        ###################################################
        sequence = tf.keras.Input(shape=(self.seq_length, 4), name='sequence')
        # self.genome = tf.keras.Input(shape=(1,), name='genome')
        current = sequence

        # augmentation
        if self.augment_rc:
            current, reverse_bool = layers.StochasticReverseComplement()(
                current)
        current = layers.StochasticShift(self.augment_shift)(current)

        ###################################################
        # build convolution blocks
        ###################################################
        for bi, block_params in enumerate(self.trunk):
            current = self.build_block(current, block_params)

        # final activation
        current = layers.activate(current, self.activation)

        # make model trunk
        trunk_output = current
        self.model_trunk = tf.keras.Model(inputs=sequence,
                                          outputs=trunk_output)

        ###################################################
        # heads
        ###################################################
        self.preds_triu = False

        head_keys = natsorted([v for v in vars(self) if v.startswith('head')])
        self.heads = [getattr(self, hk) for hk in head_keys]

        self.head_output = []
        for hi, head in enumerate(self.heads):
            if not isinstance(head, list):
                head = [head]

            # reset to trunk output
            current = trunk_output

            # build blocks
            for bi, block_params in enumerate(head):
                self.preds_triu |= (block_params['name'] == 'upper_tri')
                current = self.build_block(current, block_params)

            # transform back from reverse complement
            if self.augment_rc:
                if self.preds_triu:
                    current = layers.SwitchReverseTriu(
                        self.diagonal_offset)([current, reverse_bool])
                else:
                    current = layers.SwitchReverse()([current, reverse_bool])

            # save head output
            self.head_output.append(current)

        ###################################################
        # compile model(s)
        ###################################################
        self.models = []
        for ho in self.head_output:
            self.models.append(tf.keras.Model(inputs=sequence, outputs=ho))
        self.model = self.models[0]
        self.model.summary()

        ###################################################
        # track pooling/striding and cropping
        ###################################################
        self.model_strides = []
        self.target_lengths = []
        self.target_crops = []
        for model in self.models:
            # accumulate the total pooling/striding across this model's layers
            self.model_strides.append(1)
            for layer in model.layers:
                if hasattr(layer, 'strides'):
                    self.model_strides[-1] *= layer.strides[0]

            # full-length target size before any cropping
            if type(sequence.shape[1]) == tf.compat.v1.Dimension:
                target_full_length = sequence.shape[1].value // self.model_strides[-1]
            else:
                target_full_length = sequence.shape[1] // self.model_strides[-1]

            self.target_lengths.append(model.outputs[0].shape[1])
            if type(self.target_lengths[-1]) == tf.compat.v1.Dimension:
                self.target_lengths[-1] = self.target_lengths[-1].value
            self.target_crops.append(
                (target_full_length - self.target_lengths[-1]) // 2)
        print('model_strides', self.model_strides)
        print('target_lengths', self.target_lengths)
        print('target_crops', self.target_crops)
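
The trunk and head block parameters are plain dicts; a hypothetical sketch of what they might contain (names and values are illustrative only, assuming build_block dispatches on the 'name' key, as the 'upper_tri' check above suggests, and forwards the remaining keys as keyword arguments):

trunk_params = [
    {'name': 'conv_block', 'filters': 96, 'kernel_size': 11, 'pool_size': 2},
    {'name': 'conv_block', 'filters': 128, 'kernel_size': 5, 'pool_size': 2},
]
head_params = [
    {'name': 'dense_block', 'units': 64},
]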
Example #3
def conv_block(inputs, filters=None, kernel_size=1, activation='relu', activation_end=None,
    strides=1, dilation_rate=1, l2_scale=0, dropout=0, conv_type='standard', residual=False,
    pool_size=1, batch_norm=False, bn_momentum=0.99, bn_gamma=None, bn_type='standard',
    kernel_initializer='he_normal', padding='same'):
  """Construct a single convolution block.

  Args:
    inputs:        [batch_size, seq_length, features] input sequence
    filters:       Conv1D filters
    kernel_size:   Conv1D kernel_size
    activation:    relu/gelu/etc
    strides:       Conv1D strides
    dilation_rate: Conv1D dilation rate
    l2_scale:      L2 regularization weight.
    dropout:       Dropout rate probability
    conv_type:     Conv1D layer type
    residual:      Residual connection boolean
    pool_size:     Max pool width
    batch_norm:    Apply batch normalization
    bn_momentum:   BatchNorm momentum
    bn_gamma:      BatchNorm gamma (defaults according to residual)
    bn_type:       BatchNorm type (standard/sync)
    activation_end: Activation applied after the residual/pool operations
    kernel_initializer: Conv1D kernel initializer
    padding:       Max pool padding

  Returns:
    [batch_size, seq_length, features] output sequence
  """

  # flow through variable current
  current = inputs

  # choose convolution type
  if conv_type == 'separable':
    conv_layer = tf.keras.layers.SeparableConv1D
  else:
    conv_layer = tf.keras.layers.Conv1D

  if filters is None:
    filters = inputs.shape[-1]

  # activation
  current = layers.activate(current, activation)

  # convolution
  current = conv_layer(
    filters=filters,
    kernel_size=kernel_size,
    strides=strides,
    padding='same',
    use_bias=False,
    dilation_rate=dilation_rate,
    kernel_initializer=kernel_initializer,
    kernel_regularizer=tf.keras.regularizers.l2(l2_scale))(current)

  # batch norm
  if batch_norm:
    if bn_gamma is None:
      bn_gamma = 'zeros' if residual else 'ones'
    if bn_type == 'sync':
      bn_layer = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      bn_layer = tf.keras.layers.BatchNormalization
    current = bn_layer(
      momentum=bn_momentum,
      gamma_initializer=bn_gamma)(current)

  # dropout
  if dropout > 0:
    current = tf.keras.layers.Dropout(rate=dropout)(current)

  # residual add
  if residual:
    current = tf.keras.layers.Add()([inputs, current])

  # end activation
  if activation_end is not None:
    current = layers.activate(current, activation_end)
    
  # Pool
  if pool_size > 1:
    current = tf.keras.layers.MaxPool1D(
      pool_size=pool_size,
      padding=padding)(current)

  return current
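
Usage sketch for conv_block (illustrative only; assumes TensorFlow 2.x and the project-local `layers` module used above):

import tensorflow as tf

sequence = tf.keras.Input(shape=(1024, 4))  # one-hot encoded sequence
current = conv_block(sequence, filters=64, kernel_size=15, pool_size=2,
                     batch_norm=True)
# filters defaults to the input depth, so the residual add shapes match
current = conv_block(current, kernel_size=5, batch_norm=True, residual=True)
model = tf.keras.Model(sequence, current)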
Example #4
def dense_block(inputs,
                units=None,
                activation='relu',
                activation_end=None,
                flatten=False,
                dropout=0,
                l2_scale=0,
                l1_scale=0,
                residual=False,
                batch_norm=False,
                bn_momentum=0.99,
                bn_gamma=None,
                bn_type='standard',
                kernel_initializer='he_normal',
                **kwargs):
    """Construct a single convolution block.

  Args:
    inputs:         [batch_size, seq_length, features] input sequence
    units:          Conv1D filters
    activation:     relu/gelu/etc
    activation_end: Compute activation after the other operations
    flatten:        Flatten across positional axis
    dropout:        Dropout rate probability
    l2_scale:       L2 regularization weight.
    l1_scale:       L1 regularization weight.
    residual:       Residual connection boolean
    batch_norm:     Apply batch normalization
    bn_momentum:    BatchNorm momentum
    bn_gamma:       BatchNorm gamma (defaults according to residual)

  Returns:
    [batch_size, seq_length(?), features] output sequence
  """
    current = inputs

    if units is None:
        units = inputs.shape[-1]

    # activation
    current = layers.activate(current, activation)

    # flatten
    if flatten:
        _, seq_len, seq_depth = current.shape
        current = tf.keras.layers.Reshape((
            1,
            seq_len * seq_depth,
        ))(current)

    # dense
    current = tf.keras.layers.Dense(
        units=units,
        use_bias=(not batch_norm),
        kernel_initializer=kernel_initializer,
        kernel_regularizer=tf.keras.regularizers.l1_l2(l1_scale,
                                                       l2_scale))(current)

    # batch norm
    if batch_norm:
        if bn_gamma is None:
            bn_gamma = 'zeros' if residual else 'ones'
        if bn_type == 'sync':
            bn_layer = tf.keras.layers.experimental.SyncBatchNormalization
        else:
            bn_layer = tf.keras.layers.BatchNormalization
        current = bn_layer(momentum=bn_momentum,
                           gamma_initializer=bn_gamma)(current)

    # dropout
    if dropout > 0:
        current = tf.keras.layers.Dropout(rate=dropout)(current)

    # residual add
    if residual:
        current = tf.keras.layers.Add()([inputs, current])

    # end activation
    if activation_end is not None:
        current = layers.activate(current, activation_end)

    return current
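
Usage sketch for dense_block (illustrative only; assumes TensorFlow 2.x and the `layers` module above; flatten=True collapses the positional axis before the Dense layer):

import tensorflow as tf

features = tf.keras.Input(shape=(128, 64))
hidden = dense_block(features, units=256, flatten=True,
                     dropout=0.2, batch_norm=True)
model = tf.keras.Model(features, hidden)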
Example #5
def multihead_attention(inputs,
                        key_size=None,
                        heads=1,
                        out_size=None,
                        num_position_features=None,
                        activation='relu',
                        bn_momentum=0.9,
                        attention_dropout=0,
                        position_dropout=0,
                        dropout=0,
                        dense_expansion=0,
                        **kwargs):
    """Construct a multi-head attention block with optional dense expansion."""
    if out_size is None:
        out_size = inputs.shape[-1]
    # value size per head, derived from the output size
    value_size = out_size // heads

    # activation
    current = layers.activate(inputs, activation)

    # layer norm
    current = tf.keras.layers.LayerNormalization()(current)

    # multi-head attention
    current = layers.MultiheadAttention(
        value_size=value_size,
        key_size=key_size,
        heads=heads,
        num_position_features=num_position_features,
        attention_dropout_rate=attention_dropout,
        positional_dropout_rate=position_dropout,
        zero_initialize=False)(current)

    # batch norm
    current = tf.keras.layers.BatchNormalization(
        momentum=bn_momentum, gamma_initializer='zeros')(current)

    # dropout
    if dropout > 0:
        current = tf.keras.layers.Dropout(dropout)(current)

    # residual
    current = tf.keras.layers.Add()([inputs, current])

    if dense_expansion == 0:
        final = current
    else:
        current_mha = current

        # layer norm
        current = tf.keras.layers.LayerNormalization()(current)

        # dense
        expansion_filters = int(dense_expansion * out_size)
        current = tf.keras.layers.Dense(expansion_filters)(current)

        # dropout
        if dropout > 0:
            current = tf.keras.layers.Dropout(dropout)(current)

        # activation
        current = layers.activate(current, activation)

        # dense
        current = tf.keras.layers.Dense(out_size)(current)

        # dropout
        if dropout > 0:
            current = tf.keras.layers.Dropout(dropout)(current)

        # residual
        final = tf.keras.layers.Add()([current_mha, current])

    return final
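
Usage sketch for multihead_attention (illustrative only; assumes TensorFlow 2.x and the layers.MultiheadAttention implementation referenced above; key size, heads, and expansion factor are arbitrary):

import tensorflow as tf

embeddings = tf.keras.Input(shape=(512, 96))
attended = multihead_attention(embeddings, key_size=64, heads=8,
                               num_position_features=32, dense_expansion=2.0)
model = tf.keras.Model(embeddings, attended)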