Example #1
0
  def build_ensemble(self, ensemble_rc=False, ensemble_shifts=None):
    """ Build ensemble of models computing on augmented input sequences.

    Every augmented view of the input is run through self.model, passed
    together with its reverse indicator through layers.SwitchReverse (or
    layers.SwitchReverseTriu when self.preds_triu is set), and the
    predictions are averaged. The resulting Keras model is stored in
    self.ensemble.

    If no augmentation is requested (ensemble_rc is False and at most one
    shift is given), nothing is built and self.ensemble is left untouched.

    Args:
      ensemble_rc: If True, also predict on reverse complemented sequences.
      ensemble_shifts: Shift values handed to layers.EnsembleShift; only
        used when more than one is given. Defaults to [0].
    """
    # None sentinel instead of a shared mutable default list
    if ensemble_shifts is None:
      ensemble_shifts = [0]

    multiple_shifts = len(ensemble_shifts) > 1
    if not (ensemble_rc or multiple_shifts):
      # nothing to ensemble over
      return

    # sequence input
    sequence = tf.keras.Input(shape=(self.seq_length, 4), name='sequence')
    sequences = [sequence]

    if multiple_shifts:
      # generate shifted sequences
      sequences = layers.EnsembleShift(ensemble_shifts)(sequences)

    if ensemble_rc:
      # generate reverse complements and indicators
      sequences_rev = layers.EnsembleReverseComplement()(sequences)
    else:
      # pair every sequence with a constant "not reversed" indicator
      sequences_rev = [(seq, tf.constant(False)) for seq in sequences]

    # predict each sequence; a fresh switch layer is created per sequence
    preds = []
    for seq, rev_flag in sequences_rev:
      if self.preds_triu:
        switch = layers.SwitchReverseTriu(self.diagonal_offset)
      else:
        switch = layers.SwitchReverse()
      preds.append(switch([self.model(seq), rev_flag]))

    # average predictions across the augmented views
    preds_avg = tf.keras.layers.Average()(preds)

    # create meta model
    self.ensemble = tf.keras.Model(inputs=sequence, outputs=preds_avg)
Example #2
0
    def build_model(self, save_reprs=False):
        """Build the trunk model and one Keras model per head.

        The input is optionally reverse complemented (self.augment_rc) and
        shifted (self.augment_shift), run through the blocks in self.trunk
        and a final activation, and then through each head's blocks. Heads
        are collected from every instance attribute whose name starts with
        'head', in natural sort order.

        Sets self.model_trunk, self.preds_triu, self.heads,
        self.head_output, self.models, self.model (the first head's model),
        and the per-model lists self.model_strides, self.target_lengths and
        self.target_crops. Prints the summary of self.model and the three
        per-model lists.

        Args:
            save_reprs: Currently unused by this method.
        """
        def _dim_value(dim):
            # Unwrap TF1-style Dimension objects to plain ints.
            if isinstance(dim, tf.compat.v1.Dimension):
                return dim.value
            return dim

        ###################################################
        # inputs
        ###################################################
        sequence = tf.keras.Input(shape=(self.seq_length, 4), name='sequence')
        current = sequence

        # augmentation
        if self.augment_rc:
            # reverse_bool is fed to the SwitchReverse layers below
            current, reverse_bool = layers.StochasticReverseComplement()(
                current)
        current = layers.StochasticShift(self.augment_shift)(current)

        ###################################################
        # build convolution blocks
        ###################################################
        for block_params in self.trunk:
            current = self.build_block(current, block_params)

        # final activation
        current = layers.activate(current, self.activation)

        # make model trunk
        trunk_output = current
        self.model_trunk = tf.keras.Model(inputs=sequence,
                                          outputs=trunk_output)

        ###################################################
        # heads
        ###################################################
        self.preds_triu = False

        # NOTE(review): this scan matches any attribute starting with 'head',
        # which would include self.heads / self.head_output if this method
        # runs a second time on the same instance - confirm it is only
        # called once.
        head_keys = natsorted([v for v in vars(self) if v.startswith('head')])
        self.heads = [getattr(self, hk) for hk in head_keys]

        self.head_output = []
        for head in self.heads:
            # a head is either a single block or a list of blocks
            head_blocks = head if isinstance(head, list) else [head]

            # reset to trunk output
            current = trunk_output

            # build blocks
            for block_params in head_blocks:
                # NOTE(review): preds_triu is sticky across heads; once any
                # head has an 'upper_tri' block, later heads also use
                # SwitchReverseTriu below - confirm this is intended.
                self.preds_triu |= (block_params['name'] == 'upper_tri')
                current = self.build_block(current, block_params)

            # transform back from reverse complement
            if self.augment_rc:
                if self.preds_triu:
                    current = layers.SwitchReverseTriu(
                        self.diagonal_offset)([current, reverse_bool])
                else:
                    current = layers.SwitchReverse()([current, reverse_bool])

            # save head output
            self.head_output.append(current)

        ###################################################
        # compile model(s)
        ###################################################
        self.models = [tf.keras.Model(inputs=sequence, outputs=ho)
                       for ho in self.head_output]
        self.model = self.models[0]
        # summary() prints by itself and returns None, so it is not wrapped
        # in print()
        self.model.summary()

        ###################################################
        # track pooling/striding and cropping
        ###################################################
        self.model_strides = []
        self.target_lengths = []
        self.target_crops = []

        # input length is the same for every head
        input_length = _dim_value(sequence.shape[1])

        for model in self.models:
            # total stride is the product over this model's own layers
            # (not those of self.model, which is only the first head)
            stride = 1
            for layer in model.layers:
                if hasattr(layer, 'strides'):
                    stride *= layer.strides[0]
            self.model_strides.append(stride)

            target_full_length = input_length // stride
            target_length = _dim_value(model.outputs[0].shape[1])
            self.target_lengths.append(target_length)

            # crop per side: half the gap between strided and output length
            self.target_crops.append(
                (target_full_length - target_length) // 2)
        print('model_strides', self.model_strides)
        print('target_lengths', self.target_lengths)
        print('target_crops', self.target_crops)