# Example #1
def conv_block_2d(inputs,
                  filters=128,
                  activation='relu',
                  conv_type='standard',
                  kernel_size=1,
                  strides=1,
                  dilation_rate=1,
                  l2_scale=0,
                  dropout=0,
                  pool_size=1,
                  batch_norm=False,
                  bn_momentum=0.99,
                  bn_gamma='ones',
                  bn_type='standard',
                  symmetric=False):
    """Construct a single 2D convolution block.

    Pipeline: activation -> 2D conv -> [batch norm] -> [dropout]
    -> [max pool] -> [symmetrize]; bracketed stages are optional.
    """
    # pre-activation on the incoming tensor
    x = layers.activate(inputs, activation)

    # select the convolution implementation
    if conv_type == 'separable':
        conv_cls = tf.keras.layers.SeparableConv2D
    else:
        conv_cls = tf.keras.layers.Conv2D

    # convolution; bias is omitted since batch norm may follow
    x = conv_cls(filters=filters,
                 kernel_size=kernel_size,
                 strides=strides,
                 padding='same',
                 use_bias=False,
                 dilation_rate=dilation_rate,
                 kernel_initializer='he_normal',
                 kernel_regularizer=tf.keras.regularizers.l2(l2_scale))(x)

    # optional batch normalization (sync variant for multi-device training)
    if batch_norm:
        if bn_type == 'sync':
            bn_cls = tf.keras.layers.experimental.SyncBatchNormalization
        else:
            bn_cls = tf.keras.layers.BatchNormalization
        x = bn_cls(momentum=bn_momentum,
                   gamma_initializer=bn_gamma)(x)

    # optional dropout regularization
    if dropout > 0:
        x = tf.keras.layers.Dropout(rate=dropout)(x)

    # optional spatial max pooling
    if pool_size > 1:
        x = tf.keras.layers.MaxPool2D(pool_size=pool_size,
                                      padding='same')(x)

    # optionally enforce symmetry of the 2D map
    if symmetric:
        x = layers.Symmetrize2D()(x)

    return x
# Example #2
def conv_block(inputs,
               filters=None,
               kernel_size=1,
               activation='relu',
               activation_end=None,
               strides=1,
               dilation_rate=1,
               l2_scale=0,
               dropout=0,
               conv_type='standard',
               residual=False,
               pool_size=1,
               batch_norm=False,
               bn_momentum=0.99,
               bn_gamma=None,
               bn_type='standard',
               kernel_initializer='he_normal',
               padding='same'):
    """Construct a single 1D convolution block.

    Pipeline: activation -> Conv1D -> [batch norm] -> [dropout]
    -> [residual add] -> [end activation] -> [max pool]; bracketed
    stages are optional.

    Args:
      inputs: [batch_size, seq_length, features] input sequence.
      filters: Conv1D filter count; defaults to the input feature count.
      kernel_size: Conv1D kernel width.
      activation: pre-conv activation (relu/gelu/etc).
      activation_end: optional activation applied after the residual add.
      strides: Conv1D stride.
      dilation_rate: Conv1D dilation rate.
      l2_scale: L2 regularization weight.
      dropout: dropout rate probability.
      conv_type: 'separable' or standard Conv1D layer type.
      residual: add a residual connection from inputs.
      pool_size: max pool width.
      batch_norm: apply batch normalization.
      bn_momentum: BatchNorm momentum.
      bn_gamma: BatchNorm gamma initializer (defaults according to residual).
      bn_type: 'sync' or standard batch normalization.
      kernel_initializer: Conv1D kernel initializer.
      padding: padding mode for the max pool layer.

    Returns:
      [batch_size, seq_length, features] output sequence.
    """
    # select the convolution implementation
    if conv_type == 'separable':
        conv_cls = tf.keras.layers.SeparableConv1D
    else:
        conv_cls = tf.keras.layers.Conv1D

    # default filters to match the input feature dimension
    if filters is None:
        filters = inputs.shape[-1]

    # pre-activation on the incoming tensor
    x = layers.activate(inputs, activation)

    # convolution; bias is omitted since batch norm may follow
    # NOTE(review): the conv always pads 'same'; the `padding` argument
    # only reaches the pooling layer below — confirm this is intentional.
    x = conv_cls(filters=filters,
                 kernel_size=kernel_size,
                 strides=strides,
                 padding='same',
                 use_bias=False,
                 dilation_rate=dilation_rate,
                 kernel_initializer=kernel_initializer,
                 kernel_regularizer=tf.keras.regularizers.l2(l2_scale))(x)

    # optional batch normalization
    if batch_norm:
        # zero-init gamma lets a residual block start as identity
        if bn_gamma is None:
            bn_gamma = 'zeros' if residual else 'ones'
        if bn_type == 'sync':
            bn_cls = tf.keras.layers.experimental.SyncBatchNormalization
        else:
            bn_cls = tf.keras.layers.BatchNormalization
        x = bn_cls(momentum=bn_momentum,
                   gamma_initializer=bn_gamma)(x)

    # optional dropout regularization
    if dropout > 0:
        x = tf.keras.layers.Dropout(rate=dropout)(x)

    # optional residual connection from the block input
    if residual:
        x = tf.keras.layers.Add()([inputs, x])

    # optional closing activation
    if activation_end is not None:
        x = layers.activate(x, activation_end)

    # optional max pooling
    if pool_size > 1:
        x = tf.keras.layers.MaxPool1D(pool_size=pool_size,
                                      padding=padding)(x)

    return x
# Example #3
# File: seqnn.py  Project: polyaB/basenji
    def build_model(self, save_reprs=False):
        """Assemble the Keras model(s) from the parsed architecture.

        Builds the shared convolutional trunk, then one output head per
        `self.head*` attribute, creating one `tf.keras.Model` per head
        (all sharing the trunk). Finally records per-model stride,
        target length, and crop bookkeeping.

        Args:
          save_reprs: accepted for interface compatibility; not used here.

        Side effects: sets self.model_trunk, self.heads, self.head_output,
        self.models, self.model, self.model_strides, self.target_lengths,
        self.target_crops, self.preds_triu; writes a model summary file.
        """
        ###################################################
        # inputs
        ###################################################
        # one-hot DNA input: [batch, seq_length, 4]
        sequence = tf.keras.Input(shape=(self.seq_length, 4), name='sequence')
        current = sequence

        # augmentation: stochastic reverse-complement and shift layers
        if self.augment_rc:
            current, reverse_bool = layers.StochasticReverseComplement()(
                current)
        if self.augment_shift != [0]:
            current = layers.StochasticShift(self.augment_shift)(current)
        self.preds_triu = False

        ###################################################
        # build convolution blocks
        ###################################################
        for bi, block_params in enumerate(self.trunk):
            current = self.build_block(current, block_params)

        # final activation
        current = layers.activate(current, self.activation)

        # make model trunk
        trunk_output = current
        self.model_trunk = tf.keras.Model(inputs=sequence,
                                          outputs=trunk_output)

        ###################################################
        # heads
        ###################################################
        # collect head* attributes in natural sort order
        head_keys = natsorted([v for v in vars(self) if v.startswith('head')])
        self.heads = [getattr(self, hk) for hk in head_keys]

        self.head_output = []
        for hi, head in enumerate(self.heads):
            # a head may be a single block spec or a list of them
            if not isinstance(head, list):
                head = [head]

            # reset to trunk output
            current = trunk_output

            # build blocks
            for bi, block_params in enumerate(head):
                current = self.build_block(current, block_params)

            # transform back from reverse complement
            if self.augment_rc:
                if self.preds_triu:
                    current = layers.SwitchReverseTriu(
                        self.diagonal_offset)([current, reverse_bool])
                else:
                    current = layers.SwitchReverse()([current, reverse_bool])

            # save head output
            self.head_output.append(current)

        ###################################################
        # compile model(s)
        ###################################################
        # one model per head; models[0] is the canonical self.model
        self.models = []
        for ho in self.head_output:
            self.models.append(tf.keras.Model(inputs=sequence, outputs=ho))
        self.model = self.models[0]

        # write model summary to file
        # NOTE(review): hard-coded, machine-specific path — this will fail
        # on other machines; consider making it a parameter.
        with open(
                "/mnt/storage/home/psbelokopytova/nn_anopheles/model_summary",
                "w") as f:
            self.model.summary(print_fn=lambda x: f.write(x + '\n'))

        ###################################################
        # track pooling/striding and cropping
        ###################################################
        self.model_strides = []
        self.target_lengths = []
        self.target_crops = []
        for model in self.models:
            # cumulative stride of THIS model's layers
            # (fixed: previously iterated self.model.layers, i.e. always the
            # first model, mis-computing strides for heads with extra pooling)
            self.model_strides.append(1)
            for layer in model.layers:
                if hasattr(layer, 'strides'):
                    self.model_strides[-1] *= layer.strides[0]

            # full-length output after striding (TF1 Dimension compatibility)
            if type(sequence.shape[1]) == tf.compat.v1.Dimension:
                target_full_length = sequence.shape[
                    1].value // self.model_strides[-1]
            else:
                target_full_length = sequence.shape[1] // self.model_strides[-1]

            self.target_lengths.append(model.outputs[0].shape[1])
            if type(self.target_lengths[-1]) == tf.compat.v1.Dimension:
                self.target_lengths[-1] = self.target_lengths[-1].value
            # symmetric crop relative to the full-length output
            self.target_crops.append(
                (target_full_length - self.target_lengths[-1]) // 2)
        print('model_strides', self.model_strides)
        print('target_lengths', self.target_lengths)
        print('target_crops', self.target_crops)