def conv_block_2d(inputs, filters=128, activation='relu', conv_type='standard',
    kernel_size=1, strides=1, dilation_rate=1, l2_scale=0, dropout=0,
    pool_size=1, batch_norm=False, bn_momentum=0.99, bn_gamma='ones',
    symmetric=False):
  """Construct a single 2D convolution block."""

  # flow through variable current
  current = inputs

  # activation
  current = layers.activate(current, activation)

  # choose convolution type
  if conv_type == 'separable':
    conv_layer = tf.keras.layers.SeparableConv2D
  else:
    conv_layer = tf.keras.layers.Conv2D

  # convolution
  current = conv_layer(
    filters=filters,
    kernel_size=kernel_size,
    strides=strides,
    padding='same',
    use_bias=False,
    dilation_rate=dilation_rate,
    kernel_initializer='he_normal',
    kernel_regularizer=tf.keras.regularizers.l2(l2_scale))(current)

  # batch norm
  if batch_norm:
    current = tf.keras.layers.BatchNormalization(
      momentum=bn_momentum,
      gamma_initializer=bn_gamma,
      fused=True)(current)

  # dropout
  if dropout > 0:
    current = tf.keras.layers.Dropout(rate=dropout)(current)

  # pool
  if pool_size > 1:
    current = tf.keras.layers.MaxPool2D(
      pool_size=pool_size,
      padding='same')(current)

  # symmetrize across the matrix diagonal
  if symmetric:
    current = layers.Symmetrize2D()(current)

  return current
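# Illustrative usage sketch (not part of the library): wraps conv_block_2d
# in a small Keras model. The input shape and hyperparameters below are
# hypothetical; `tf` and `layers` are the module-level imports used above.
def _example_conv_block_2d():
  """Minimal sketch: one symmetric 2D conv block on a contact-map-like input."""
  inputs = tf.keras.Input(shape=(64, 64, 48))
  outputs = conv_block_2d(inputs, filters=32, kernel_size=3,
                          batch_norm=True, symmetric=True)
  return tf.keras.Model(inputs=inputs, outputs=outputs)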
def build_model(self, save_reprs=False):
  ###################################################
  # inputs
  ###################################################
  sequence = tf.keras.Input(shape=(self.seq_length, 4), name='sequence')
  # self.genome = tf.keras.Input(shape=(1,), name='genome')
  current = sequence

  # augmentation
  if self.augment_rc:
    current, reverse_bool = layers.StochasticReverseComplement()(current)
  current = layers.StochasticShift(self.augment_shift)(current)

  ###################################################
  # build convolution blocks
  ###################################################
  for bi, block_params in enumerate(self.trunk):
    current = self.build_block(current, block_params)

  # final activation
  current = layers.activate(current, self.activation)

  # make model trunk
  trunk_output = current
  self.model_trunk = tf.keras.Model(inputs=sequence, outputs=trunk_output)

  ###################################################
  # heads
  ###################################################
  self.preds_triu = False

  head_keys = natsorted([v for v in vars(self) if v.startswith('head')])
  self.heads = [getattr(self, hk) for hk in head_keys]

  self.head_output = []
  for hi, head in enumerate(self.heads):
    if not isinstance(head, list):
      head = [head]

    # reset to trunk output
    current = trunk_output

    # build blocks
    for bi, block_params in enumerate(head):
      self.preds_triu |= (block_params['name'] == 'upper_tri')
      current = self.build_block(current, block_params)

    # transform back from reverse complement
    if self.augment_rc:
      if self.preds_triu:
        current = layers.SwitchReverseTriu(
          self.diagonal_offset)([current, reverse_bool])
      else:
        current = layers.SwitchReverse()([current, reverse_bool])

    # save head output
    self.head_output.append(current)

  ###################################################
  # compile model(s)
  ###################################################
  self.models = []
  for ho in self.head_output:
    self.models.append(tf.keras.Model(inputs=sequence, outputs=ho))
  self.model = self.models[0]
  print(self.model.summary())

  ###################################################
  # track pooling/striding and cropping
  ###################################################
  self.model_strides = []
  self.target_lengths = []
  self.target_crops = []
  for model in self.models:
    # accumulate pooling/striding factors across this model's layers
    self.model_strides.append(1)
    for layer in model.layers:
      if hasattr(layer, 'strides'):
        self.model_strides[-1] *= layer.strides[0]

    # compute the full sequence length at the output resolution
    if type(sequence.shape[1]) == tf.compat.v1.Dimension:
      target_full_length = sequence.shape[1].value // self.model_strides[-1]
    else:
      target_full_length = sequence.shape[1] // self.model_strides[-1]

    # record the target length and the cropping it implies
    self.target_lengths.append(model.outputs[0].shape[1])
    if type(self.target_lengths[-1]) == tf.compat.v1.Dimension:
      self.target_lengths[-1] = self.target_lengths[-1].value
    self.target_crops.append(
      (target_full_length - self.target_lengths[-1]) // 2)

  print('model_strides', self.model_strides)
  print('target_lengths', self.target_lengths)
  print('target_crops', self.target_crops)
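# Illustrative sketch (not part of the class): the stride bookkeeping above
# multiplies the `strides` of every layer that has one. This hypothetical
# standalone model shows the same computation yielding a total stride of 4.
def _example_model_stride():
  """Minimal sketch: accumulate pooling/striding factors across a model."""
  sequence = tf.keras.Input(shape=(256, 4))
  current = tf.keras.layers.Conv1D(8, 5, padding='same')(sequence)
  current = tf.keras.layers.MaxPool1D(2)(current)
  current = tf.keras.layers.MaxPool1D(2)(current)
  model = tf.keras.Model(inputs=sequence, outputs=current)

  model_stride = 1
  for layer in model.layers:
    if hasattr(layer, 'strides'):
      model_stride *= layer.strides[0]
  return model_stride  # 1 * 1 * 2 * 2 = 4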
def conv_block(inputs, filters=None, kernel_size=1, activation='relu',
    activation_end=None, strides=1, dilation_rate=1, l2_scale=0, dropout=0,
    conv_type='standard', residual=False, pool_size=1, batch_norm=False,
    bn_momentum=0.99, bn_gamma=None, bn_type='standard',
    kernel_initializer='he_normal', padding='same'):
  """Construct a single convolution block.

  Args:
    inputs: [batch_size, seq_length, features] input sequence
    filters: Conv1D filters (defaults to the input feature depth)
    kernel_size: Conv1D kernel_size
    activation: relu/gelu/etc
    activation_end: Activation applied after the residual add
    strides: Conv1D strides
    dilation_rate: Conv1D dilation rate
    l2_scale: L2 regularization weight
    dropout: Dropout rate probability
    conv_type: Conv1D layer type (standard/separable)
    residual: Residual connection boolean
    pool_size: Max pool width
    batch_norm: Apply batch normalization
    bn_momentum: BatchNorm momentum
    bn_gamma: BatchNorm gamma (defaults according to residual)
    bn_type: BatchNorm type (standard/sync)
    kernel_initializer: Convolution kernel initializer
    padding: Max pool padding

  Returns:
    [batch_size, seq_length, features] output sequence
  """

  # flow through variable current
  current = inputs

  # choose convolution type
  if conv_type == 'separable':
    conv_layer = tf.keras.layers.SeparableConv1D
  else:
    conv_layer = tf.keras.layers.Conv1D

  # default to preserving the input feature depth
  if filters is None:
    filters = inputs.shape[-1]

  # activation
  current = layers.activate(current, activation)

  # convolution
  current = conv_layer(
    filters=filters,
    kernel_size=kernel_size,
    strides=strides,
    padding='same',
    use_bias=False,
    dilation_rate=dilation_rate,
    kernel_initializer=kernel_initializer,
    kernel_regularizer=tf.keras.regularizers.l2(l2_scale))(current)

  # batch norm
  if batch_norm:
    if bn_gamma is None:
      bn_gamma = 'zeros' if residual else 'ones'
    if bn_type == 'sync':
      bn_layer = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      bn_layer = tf.keras.layers.BatchNormalization
    current = bn_layer(
      momentum=bn_momentum,
      gamma_initializer=bn_gamma)(current)

  # dropout
  if dropout > 0:
    current = tf.keras.layers.Dropout(rate=dropout)(current)

  # residual add
  if residual:
    current = tf.keras.layers.Add()([inputs, current])

  # end activation
  if activation_end is not None:
    current = layers.activate(current, activation_end)

  # pool
  if pool_size > 1:
    current = tf.keras.layers.MaxPool1D(
      pool_size=pool_size,
      padding=padding)(current)

  return current
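# Illustrative usage sketch (not part of the library): a short conv tower
# built from conv_block. Shapes and hyperparameters are hypothetical; the
# residual block leaves filters=None so the Add() dimensions match.
def _example_conv_tower():
  """Minimal sketch: pooled conv block followed by a residual dilated block."""
  sequence = tf.keras.Input(shape=(1024, 4))
  current = conv_block(sequence, filters=64, kernel_size=15, pool_size=2,
                       batch_norm=True)
  current = conv_block(current, kernel_size=3, dilation_rate=2,
                       residual=True, batch_norm=True)
  return tf.keras.Model(inputs=sequence, outputs=current)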
def dense_block(inputs, units=None, activation='relu', activation_end=None,
    flatten=False, dropout=0, l2_scale=0, l1_scale=0, residual=False,
    batch_norm=False, bn_momentum=0.99, bn_gamma=None, bn_type='standard',
    kernel_initializer='he_normal', **kwargs):
  """Construct a single dense block.

  Args:
    inputs: [batch_size, seq_length, features] input sequence
    units: Dense units (defaults to the input feature depth)
    activation: relu/gelu/etc
    activation_end: Compute activation after the other operations
    flatten: Flatten across positional axis
    dropout: Dropout rate probability
    l2_scale: L2 regularization weight
    l1_scale: L1 regularization weight
    residual: Residual connection boolean
    batch_norm: Apply batch normalization
    bn_momentum: BatchNorm momentum
    bn_gamma: BatchNorm gamma (defaults according to residual)
    bn_type: BatchNorm type (standard/sync)
    kernel_initializer: Dense kernel initializer

  Returns:
    [batch_size, seq_length(?), features] output sequence
  """
  current = inputs

  # default to preserving the input feature depth
  if units is None:
    units = inputs.shape[-1]

  # activation
  current = layers.activate(current, activation)

  # flatten across the positional axis
  if flatten:
    _, seq_len, seq_depth = current.shape
    current = tf.keras.layers.Reshape((1, seq_len * seq_depth,))(current)

  # dense
  current = tf.keras.layers.Dense(
    units=units,
    use_bias=(not batch_norm),
    kernel_initializer=kernel_initializer,
    kernel_regularizer=tf.keras.regularizers.l1_l2(l1_scale, l2_scale))(current)

  # batch norm
  if batch_norm:
    if bn_gamma is None:
      bn_gamma = 'zeros' if residual else 'ones'
    if bn_type == 'sync':
      bn_layer = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      bn_layer = tf.keras.layers.BatchNormalization
    current = bn_layer(
      momentum=bn_momentum,
      gamma_initializer=bn_gamma)(current)

  # dropout
  if dropout > 0:
    current = tf.keras.layers.Dropout(rate=dropout)(current)

  # residual add
  if residual:
    current = tf.keras.layers.Add()([inputs, current])

  # end activation
  if activation_end is not None:
    current = layers.activate(current, activation_end)

  return current
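# Illustrative usage sketch (not part of the library): dense_block as a
# flattening head. Shapes are hypothetical; with flatten=True the positional
# axis collapses before the Dense layer, so residual must stay False.
def _example_dense_head():
  """Minimal sketch: flatten positions and project to a fixed-size output."""
  inputs = tf.keras.Input(shape=(128, 64))
  outputs = dense_block(inputs, units=32, flatten=True,
                        dropout=0.1, batch_norm=True)
  return tf.keras.Model(inputs=inputs, outputs=outputs)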
def multihead_attention(inputs, key_size=None, heads=1, out_size=None,
    num_position_features=None, activation='relu', bn_momentum=0.9,
    attention_dropout=0, position_dropout=0, dropout=0,
    dense_expansion=0, **kwargs):
  """Construct a multi-head attention block with optional dense expansion."""

  # default to preserving the input feature depth, split across heads
  if out_size is None:
    out_size = inputs.shape[-1]
  value_size = out_size // heads

  # activation
  current = layers.activate(inputs, activation)

  # layer norm
  current = tf.keras.layers.LayerNormalization()(current)

  # multi-head attention
  current = layers.MultiheadAttention(
    value_size=value_size,
    key_size=key_size,
    heads=heads,
    num_position_features=num_position_features,
    attention_dropout_rate=attention_dropout,
    positional_dropout_rate=position_dropout,
    zero_initialize=False)(current)

  # batch norm
  current = tf.keras.layers.BatchNormalization(
    momentum=bn_momentum,
    gamma_initializer='zeros')(current)

  # dropout
  if dropout > 0:
    current = tf.keras.layers.Dropout(dropout)(current)

  # residual
  current = tf.keras.layers.Add()([inputs, current])

  if dense_expansion == 0:
    final = current
  else:
    current_mha = current

    # layer norm
    current = tf.keras.layers.LayerNormalization()(current)

    # dense expansion
    expansion_filters = int(dense_expansion * out_size)
    current = tf.keras.layers.Dense(expansion_filters)(current)

    # dropout
    if dropout > 0:
      current = tf.keras.layers.Dropout(dropout)(current)

    # activation
    current = layers.activate(current, activation)

    # dense contraction back to out_size
    current = tf.keras.layers.Dense(out_size)(current)

    # dropout
    if dropout > 0:
      current = tf.keras.layers.Dropout(dropout)(current)

    # residual
    final = tf.keras.layers.Add()([current_mha, current])

  return final
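# Illustrative usage sketch (not part of the library): one transformer-style
# block built from multihead_attention. Hyperparameters are hypothetical;
# out_size defaults to the input depth (96), giving value_size 96 // 8 = 12,
# and dense_expansion=2.0 enables the feed-forward sub-block.
def _example_attention_block():
  """Minimal sketch: pre-norm multi-head attention with dense expansion."""
  inputs = tf.keras.Input(shape=(512, 96))
  outputs = multihead_attention(inputs, key_size=64, heads=8,
                                attention_dropout=0.05, dropout=0.1,
                                dense_expansion=2.0)
  return tf.keras.Model(inputs=inputs, outputs=outputs)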