def build_ensemble(self, ensemble_rc=False, ensemble_shifts=None):
    """Build an ensemble model that averages predictions over augmented inputs.

    When reverse-complement ensembling is requested and/or more than one
    shift is given, a new Keras model is assembled that feeds each augmented
    copy of the input through ``self.model``, passes each prediction through
    a ``SwitchReverse``/``SwitchReverseTriu`` layer together with its
    reverse indicator, and averages the results. The model is stored in
    ``self.ensemble``.

    If neither augmentation is requested, the method returns without
    touching ``self.ensemble``.

    Args:
      ensemble_rc: If true, also predict on reverse-complemented sequences.
      ensemble_shifts: Sequence of shifts to ensemble over. ``None`` (the
        default) is treated as ``[0]``, i.e. no shift ensembling.
    """
    # A None sentinel replaces the former mutable ``[0]`` default so that no
    # list object is shared between calls.
    if ensemble_shifts is None:
        ensemble_shifts = [0]

    use_shifts = len(ensemble_shifts) > 1
    if not (ensemble_rc or use_shifts):
        # nothing to ensemble over
        return

    # sequence input
    sequence = tf.keras.Input(shape=(self.seq_length, 4), name='sequence')
    sequences = [sequence]

    if use_shifts:
        # generate shifted sequences
        sequences = layers.EnsembleShift(ensemble_shifts)(sequences)

    if ensemble_rc:
        # generate reverse complements and indicators
        sequences_rev = layers.EnsembleReverseComplement()(sequences)
    else:
        # pair every sequence with a constant "not reversed" indicator
        sequences_rev = [(seq, tf.constant(False)) for seq in sequences]

    # predict each sequence; a fresh switch layer is created per prediction
    if self.preds_triu:
        preds = [
            layers.SwitchReverseTriu(self.diagonal_offset)(
                [self.model(seq), rp])
            for seq, rp in sequences_rev
        ]
    else:
        preds = [
            layers.SwitchReverse()([self.model(seq), rp])
            for seq, rp in sequences_rev
        ]

    # average the individual predictions
    preds_avg = tf.keras.layers.Average()(preds)

    # create meta model
    self.ensemble = tf.keras.Model(inputs=sequence, outputs=preds_avg)
def build_model(self, save_reprs=False):
    """Build the trunk model, one Keras model per head, and length bookkeeping.

    Steps performed:
      1. Create the sequence input and apply the stochastic augmentation
         layers (reverse complement only when ``self.augment_rc`` is set).
      2. Apply every block in ``self.trunk`` via ``self.build_block``, then
         the final activation, and store the result as ``self.model_trunk``.
      3. Collect every instance attribute whose name starts with ``head``
         (natural-sorted), build its block(s) on top of the trunk output,
         and undo the reverse complement when augmentation is active.
      4. Wrap each head output in its own ``tf.keras.Model``
         (``self.models``); ``self.model`` is the first of them.
      5. For each model record the product of its layers' strides, its
         output length, and the per-side crop implied by the two.

    Attributes set: ``model_trunk``, ``preds_triu``, ``heads``,
    ``head_output``, ``models``, ``model``, ``model_strides``,
    ``target_lengths``, ``target_crops``.

    Args:
      save_reprs: Not used by this method; kept so existing callers that
        pass it continue to work.
    """

    def dim_to_int(dim):
        # TF1-style shapes hold Dimension objects; unwrap them to their value.
        if isinstance(dim, tf.compat.v1.Dimension):
            return dim.value
        return dim

    ###################################################
    # inputs
    ###################################################
    sequence = tf.keras.Input(shape=(self.seq_length, 4), name='sequence')
    current = sequence

    # augmentation
    if self.augment_rc:
        # reverse_bool records whether the input was reversed, so that the
        # head outputs can be switched back below
        current, reverse_bool = layers.StochasticReverseComplement()(current)
    current = layers.StochasticShift(self.augment_shift)(current)

    ###################################################
    # build convolution blocks
    ###################################################
    for block_params in self.trunk:
        current = self.build_block(current, block_params)

    # final activation
    current = layers.activate(current, self.activation)

    # make model trunk
    trunk_output = current
    self.model_trunk = tf.keras.Model(inputs=sequence, outputs=trunk_output)

    ###################################################
    # heads
    ###################################################
    self.preds_triu = False

    head_keys = natsorted([v for v in vars(self) if v.startswith('head')])
    self.heads = [getattr(self, hk) for hk in head_keys]

    self.head_output = []
    for head in self.heads:
        # a head may be a single block description or a list of them
        if not isinstance(head, list):
            head = [head]

        # reset to trunk output
        current = trunk_output

        # build blocks
        for block_params in head:
            # NOTE: the flag is sticky; once any head contains an
            # 'upper_tri' block it stays set for all later heads.
            self.preds_triu |= (block_params['name'] == 'upper_tri')
            current = self.build_block(current, block_params)

        # transform back from reverse complement
        if self.augment_rc:
            if self.preds_triu:
                current = layers.SwitchReverseTriu(
                    self.diagonal_offset)([current, reverse_bool])
            else:
                current = layers.SwitchReverse()([current, reverse_bool])

        # save head output
        self.head_output.append(current)

    ###################################################
    # compile model(s)
    ###################################################
    self.models = [
        tf.keras.Model(inputs=sequence, outputs=ho)
        for ho in self.head_output
    ]
    self.model = self.models[0]
    # summary() prints the table itself and returns None, so it must not be
    # wrapped in print() (that would emit a stray "None" line).
    self.model.summary()

    ###################################################
    # track pooling/striding and cropping
    ###################################################
    self.model_strides = []
    self.target_lengths = []
    self.target_crops = []

    # the input length is the same for every head
    seq_len = dim_to_int(sequence.shape[1])

    for model in self.models:
        # Multiply the strides of this model's own layers; the previous code
        # walked self.model here, so every head got the first head's stride.
        stride = 1
        for layer in model.layers:
            if hasattr(layer, 'strides'):
                stride *= layer.strides[0]

        target_full_length = seq_len // stride
        target_length = dim_to_int(model.outputs[0].shape[1])

        self.model_strides.append(stride)
        self.target_lengths.append(target_length)
        # positions removed from each side relative to the uncropped length
        self.target_crops.append((target_full_length - target_length) // 2)

    print('model_strides', self.model_strides)
    print('target_lengths', self.target_lengths)
    print('target_crops', self.target_crops)