def conv_block_2d(inputs, filters=128, activation='relu', conv_type='standard',
                  kernel_size=1, strides=1, dilation_rate=1, l2_scale=0, dropout=0,
                  pool_size=1, batch_norm=False, bn_momentum=0.99, bn_gamma='ones',
                  bn_type='standard', symmetric=False,
                  kernel_initializer='he_normal', padding='same'):
  """Construct a single 2D convolution block.

  Args:
    inputs: [batch_size, height, width, features] input tensor
    filters: Conv2D filters
    activation: relu/gelu/etc
    conv_type: 'separable' for SeparableConv2D, otherwise standard Conv2D
    kernel_size: Conv2D kernel_size
    strides: Conv2D strides
    dilation_rate: Conv2D dilation rate
    l2_scale: L2 regularization weight
    dropout: Dropout rate probability
    pool_size: Max pool width (applied only when > 1)
    batch_norm: Apply batch normalization
    bn_momentum: BatchNorm momentum
    bn_gamma: BatchNorm gamma initializer
    bn_type: 'sync' for SyncBatchNormalization, otherwise standard
    symmetric: Apply Symmetrize2D to enforce output symmetry
    kernel_initializer: Conv2D kernel initializer (new, defaults to prior
      hard-coded 'he_normal' for backward compatibility / consistency with
      conv_block)
    padding: Conv2D/MaxPool2D padding (new, defaults to prior hard-coded
      'same')

  Returns:
    output tensor
  """
  # flow through variable current
  current = inputs

  # activation
  current = layers.activate(current, activation)

  # choose convolution type
  if conv_type == 'separable':
    conv_layer = tf.keras.layers.SeparableConv2D
  else:
    conv_layer = tf.keras.layers.Conv2D

  # convolution (no bias: BatchNorm or a later layer supplies the shift)
  current = conv_layer(
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=padding,
      use_bias=False,
      dilation_rate=dilation_rate,
      kernel_initializer=kernel_initializer,
      kernel_regularizer=tf.keras.regularizers.l2(l2_scale))(current)

  # batch norm
  if batch_norm:
    if bn_type == 'sync':
      bn_layer = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      bn_layer = tf.keras.layers.BatchNormalization
    current = bn_layer(momentum=bn_momentum, gamma_initializer=bn_gamma)(current)

  # dropout
  if dropout > 0:
    current = tf.keras.layers.Dropout(rate=dropout)(current)

  # pool
  if pool_size > 1:
    current = tf.keras.layers.MaxPool2D(pool_size=pool_size,
                                        padding=padding)(current)

  # symmetric
  if symmetric:
    current = layers.Symmetrize2D()(current)

  return current
def conv_block(inputs, filters=None, kernel_size=1, activation='relu',
               activation_end=None, strides=1, dilation_rate=1, l2_scale=0,
               dropout=0, conv_type='standard', residual=False, pool_size=1,
               batch_norm=False, bn_momentum=0.99, bn_gamma=None,
               bn_type='standard', kernel_initializer='he_normal',
               padding='same'):
  """Construct a single convolution block.

  Args:
    inputs: [batch_size, seq_length, features] input sequence
    filters: Conv1D filters (defaults to the input feature dimension)
    kernel_size: Conv1D kernel_size
    activation: relu/gelu/etc
    activation_end: optional activation applied after the residual add
    strides: Conv1D strides
    dilation_rate: Conv1D dilation rate
    l2_scale: L2 regularization weight.
    dropout: Dropout rate probability
    conv_type: Conv1D layer type ('separable' or standard)
    residual: Residual connection boolean
    pool_size: Max pool width
    batch_norm: Apply batch normalization
    bn_momentum: BatchNorm momentum
    bn_gamma: BatchNorm gamma (defaults according to residual)
    kernel_initializer: Conv1D kernel initializer
    padding: Conv1D/MaxPool1D padding

  Returns:
    [batch_size, seq_length, features] output sequence
  """
  # flow through variable current
  current = inputs

  # choose convolution type
  if conv_type == 'separable':
    conv_layer = tf.keras.layers.SeparableConv1D
  else:
    conv_layer = tf.keras.layers.Conv1D

  # default filters to input features (required for the residual add below)
  if filters is None:
    filters = inputs.shape[-1]

  # activation
  current = layers.activate(current, activation)

  # convolution
  # BUGFIX: previously hard-coded padding='same', silently ignoring the
  # padding parameter (it was only honored by the pool below). Default is
  # still 'same', so existing callers are unaffected.
  current = conv_layer(
      filters=filters,
      kernel_size=kernel_size,
      strides=strides,
      padding=padding,
      use_bias=False,
      dilation_rate=dilation_rate,
      kernel_initializer=kernel_initializer,
      kernel_regularizer=tf.keras.regularizers.l2(l2_scale))(current)

  # batch norm
  if batch_norm:
    if bn_gamma is None:
      # zero-init gamma makes a residual block start as identity
      bn_gamma = 'zeros' if residual else 'ones'
    if bn_type == 'sync':
      bn_layer = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      bn_layer = tf.keras.layers.BatchNormalization
    current = bn_layer(momentum=bn_momentum, gamma_initializer=bn_gamma)(current)

  # dropout
  if dropout > 0:
    current = tf.keras.layers.Dropout(rate=dropout)(current)

  # residual add
  if residual:
    current = tf.keras.layers.Add()([inputs, current])

  # end activation
  if activation_end is not None:
    current = layers.activate(current, activation_end)

  # Pool
  if pool_size > 1:
    current = tf.keras.layers.MaxPool1D(pool_size=pool_size,
                                        padding=padding)(current)

  return current
def build_model(self, save_reprs=False):
  """Build the Keras model(s): augmentation + trunk + one model per head.

  Constructs the sequence input and stochastic augmentation layers, stacks
  the trunk blocks, then builds each 'head*' attribute into a separate
  tf.keras.Model sharing the trunk. Also records per-model stride, target
  length, and crop bookkeeping.

  Args:
    save_reprs: unused here; kept for interface compatibility.

  Side effects:
    Sets self.model_trunk, self.heads, self.head_output, self.models,
    self.model, self.model_strides, self.target_lengths, self.target_crops,
    self.preds_triu; best-effort writes a model summary file.
  """
  ###################################################
  # inputs
  ###################################################
  sequence = tf.keras.Input(shape=(self.seq_length, 4), name='sequence')
  current = sequence

  # augmentation
  if self.augment_rc:
    current, reverse_bool = layers.StochasticReverseComplement()(current)
  if self.augment_shift != [0]:
    current = layers.StochasticShift(self.augment_shift)(current)
  self.preds_triu = False

  ###################################################
  # build convolution blocks
  ###################################################
  for bi, block_params in enumerate(self.trunk):
    current = self.build_block(current, block_params)

  # final activation
  current = layers.activate(current, self.activation)

  # make model trunk
  trunk_output = current
  self.model_trunk = tf.keras.Model(inputs=sequence, outputs=trunk_output)

  ###################################################
  # heads
  ###################################################
  # natural sort so head0, head1, ... head10 order correctly
  head_keys = natsorted([v for v in vars(self) if v.startswith('head')])
  self.heads = [getattr(self, hk) for hk in head_keys]

  self.head_output = []
  for hi, head in enumerate(self.heads):
    if not isinstance(head, list):
      head = [head]

    # reset to trunk output
    current = trunk_output

    # build blocks
    for bi, block_params in enumerate(head):
      current = self.build_block(current, block_params)

    # transform back from reverse complement
    if self.augment_rc:
      if self.preds_triu:
        current = layers.SwitchReverseTriu(
            self.diagonal_offset)([current, reverse_bool])
      else:
        current = layers.SwitchReverse()([current, reverse_bool])

    # save head output
    self.head_output.append(current)

  ###################################################
  # compile model(s)
  ###################################################
  self.models = []
  for ho in self.head_output:
    self.models.append(tf.keras.Model(inputs=sequence, outputs=ho))
  self.model = self.models[0]

  # write model summary in file
  # TODO(review): hard-coded user-specific absolute path; make configurable.
  summary_path = "/mnt/storage/home/psbelokopytova/nn_anopheles/model_summary"
  try:
    with open(summary_path, "w") as f:
      self.model.summary(print_fn=lambda x: f.write(x + '\n'))
  except OSError as e:
    # summary file is diagnostic only; don't abort model construction
    # on machines where the path does not exist
    print('build_model: could not write model summary to %s: %s' %
          (summary_path, e))

  ###################################################
  # track pooling/striding and cropping
  ###################################################
  self.model_strides = []
  self.target_lengths = []
  self.target_crops = []
  for model in self.models:
    # accumulate total striding along the sequence axis
    self.model_strides.append(1)
    # BUGFIX: previously iterated self.model.layers (always the FIRST
    # model), so every head reused model 0's stride product.
    for layer in model.layers:
      if hasattr(layer, 'strides'):
        self.model_strides[-1] *= layer.strides[0]

    # full sequence length after striding; unwrap TF1 Dimension objects
    if type(sequence.shape[1]) == tf.compat.v1.Dimension:
      target_full_length = sequence.shape[1].value // self.model_strides[-1]
    else:
      target_full_length = sequence.shape[1] // self.model_strides[-1]

    self.target_lengths.append(model.outputs[0].shape[1])
    if type(self.target_lengths[-1]) == tf.compat.v1.Dimension:
      self.target_lengths[-1] = self.target_lengths[-1].value
    # symmetric crop implied by (full length - output length)
    self.target_crops.append(
        (target_full_length - self.target_lengths[-1]) // 2)

  print('model_strides', self.model_strides)
  print('target_lengths', self.target_lengths)
  print('target_crops', self.target_crops)