Example #1
    def format_seq(self, seq, stop=False):
        """
        Takes an amino acid sequence and returns a list of integers in the codex of the babbler.
        The default (stop=False) strips the stop symbol that would otherwise be appended
        to the end of the sequence. If you are trying to generate a rep, do not include
        the stop. It is probably best to ignore the stop if you are co-tuning the babbler
        and a top model as well.
        """
        if stop:
            int_seq = aa_seq_to_int(seq.strip())
        else:
            int_seq = aa_seq_to_int(seq.strip())[:-1]
        return int_seq
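For intuition, here is a toy stand-in for aa_seq_to_int with a made-up five-symbol codex; the real babbler vocabulary is larger and defined elsewhere in the codebase, so treat the mapping below purely as an illustration of the start/stop bracketing that format_seq strips:

TOY_VOCAB = {'start': 1, 'M': 2, 'K': 3, 'V': 4, 'stop': 5}  # hypothetical codex

def toy_aa_seq_to_int(seq):
    """Map a sequence to ints, bracketed by start and stop tokens."""
    return [TOY_VOCAB['start']] + [TOY_VOCAB[aa] for aa in seq] + [TOY_VOCAB['stop']]

print(toy_aa_seq_to_int('MKV'))       # [1, 2, 3, 4, 5]
print(toy_aa_seq_to_int('MKV')[:-1])  # [1, 2, 3, 4] -- what format_seq returns by default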
Example #2
    def get_rep(self,seq):
        """
        Input a valid amino acid sequence,
        outputs a tuple of average hidden, final hidden, final cell representation arrays.
        Unfortunately, this method accepts one sequence at a time and is as such quite
        slow.
        """
        with tf.compat.v1.Session() as sess:
            initialize_uninitialized(sess)
            # Strip any whitespace and convert to integers with the correct coding
            int_seq = aa_seq_to_int(seq.strip())[:-1]
            # Final state is a cell_state, hidden_state tuple. Output is
            # all hidden states
            final_state_, hs = sess.run(
                [self._final_state, self._output], feed_dict={
                    self._batch_size_placeholder: 1,
                    self._minibatch_x_placeholder: [int_seq],
                    self._initial_state_placeholder: self._zero_state}
            )

        final_cell, final_hidden = final_state_
        # Drop the batch dimension so it is just seq len by
        # representation size
        final_cell = final_cell[0]
        final_hidden = final_hidden[0]
        hs = hs[0]
        avg_hidden = np.mean(hs, axis=0)
        return avg_hidden, final_hidden, final_cell
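For intuition on the shapes involved, a minimal numpy sketch (toy sizes, not the model's actual rnn_size) of how the per-position hidden states collapse into the fixed-length average representation; for an LSTM, the final hidden state equals the last row of the output matrix:

import numpy as np

seq_len, rep_size = 6, 4                 # toy sizes for illustration only
hs = np.random.rand(seq_len, rep_size)   # stands in for the [seq_len x rep_size] hidden states
avg_hidden = np.mean(hs, axis=0)         # fixed-length vector, shape (rep_size,)
final_hidden = hs[-1]                    # last hidden state, matching the final_hidden output
assert avg_hidden.shape == (rep_size,)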
Example #3
import numpy as np
import scipy.special
import matplotlib.pyplot as plt

def calc_seq_loglike(seq, logits, method='smart', plot=False):
    """
    seq: Amino acid sequence (str)
    logits: A [seq_length x vocab_size] array
    method: 'smart' computes the log-likelihood in log space;
        'dumb' materializes the softmax probabilities first.
    plot: If True, show the logits with the target indices overlaid.
    """
    # Convert to integer seq, and drop the first token since
    # logits are next-char predictions.
    iseq = aa_seq_to_int(seq)[1:]
    iseq = [i - 1 for i in iseq]  # subtract 1 since logits don't consider pad

    if plot:
        plt.figure(figsize=(20, 4))
        plt.imshow(logits.T, aspect='auto')
        for i in range(len(iseq)):
            plt.plot(i, iseq[i], '.k')
        plt.show()

    if method == 'dumb':
        sm = scipy.special.softmax(logits, axis=1)
        log_like = 0
        for i in range(sm.shape[0]):
            log_like += np.log(sm[i, iseq[i]])
    elif method == 'smart':
        lse = scipy.special.logsumexp(logits, axis=1)
        aa_logits = np.array([logits[i, iseq[i]] for i in range(logits.shape[0])])
        log_like = np.sum(aa_logits - lse)
    else:
        raise ValueError("method must be 'smart' or 'dumb'")

    return log_like
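The 'smart' branch works because log softmax(x)_i = x_i - logsumexp(x), so no small probability is ever materialized and underflow is avoided. A self-contained check (random logits, hypothetical vocab size) showing the two branches agree:

import numpy as np
import scipy.special

rng = np.random.default_rng(0)
logits = rng.normal(size=(8, 25))   # 8 positions, toy 25-symbol vocab
iseq = rng.integers(0, 25, size=8)  # toy target indices

# 'dumb': exponentiate, then take logs of the picked probabilities
sm = scipy.special.softmax(logits, axis=1)
dumb = sum(np.log(sm[i, iseq[i]]) for i in range(8))

# 'smart': work entirely in log space
lse = scipy.special.logsumexp(logits, axis=1)
smart = np.sum(logits[np.arange(8), iseq] - lse)

assert np.isclose(dumb, smart)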
Example #4
    def get_rep(self, seq):
        """
        get_rep needs to be slightly adjusted to accommodate the different state size
        of the stack.
        Input a valid amino acid sequence, 
        outputs a tuple of average hidden, final hidden, final cell representation arrays.
        Unfortunately, this method accepts one sequence at a time and is as such quite
        slow.
        """
        with tf.Session() as sess:
            initialize_uninitialized(sess)
            # Strip any whitespace and convert to integers with the correct coding
            int_seq = aa_seq_to_int(seq.strip())[:-1]
            # Final state is a cell_state, hidden_state tuple. Output is
            # all hidden states
            final_state_, hs = sess.run(
                [self._final_state, self._output],
                feed_dict={
                    self._batch_size_placeholder: 1,
                    self._minibatch_x_placeholder: [int_seq],
                    self._initial_state_placeholder: self._zero_state
                })

        final_cell, final_hidden = final_state_
        # Because this is a deep model, each of final hidden and final cell is tuple of num_layers
        final_cell = final_cell[-1]
        final_hidden = final_hidden[-1]
        hs = hs[0]
        avg_hidden = np.mean(hs, axis=0)
        return avg_hidden, final_hidden[0], final_cell[0]
Example #5
    def get_loss_batch(self, seqs, losstype='avg',
                       batch_size=None):
        # assert all lengths equal
        # ToDo - split seqs into batches
        if batch_size is None:
            batch_size = len(seqs)
        seq_len = len(seqs[0])
        for seq in seqs:
            assert len(seq) == seq_len
        with tf.Session() as sess:
            initialize_uninitialized(sess)
            int_seqs = np.array([aa_seq_to_int(seq.strip())[:-1] for seq in seqs])
            loss = sess.run(
                self._batch_seq_losses, feed_dict={
                    self._batch_size_placeholder: batch_size,
                    self._minibatch_x_placeholder: int_seqs,
                    self._minibatch_y_placeholder: int_seqs,
                    self._initial_state_placeholder: np_zero_state(batch_size,
                                                                   self._rnn_size,
                                                                   self._num_layers),
                    # self._seq_length_placeholder: [seq_len]*batch_size
                })
        # return np.mean(loss, axis=1)
        return loss
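The ToDo above notes that seqs should be split into batches; a minimal chunking sketch (generic Python, not tied to the babbler internals) of how that outer loop could look, with the usage line assuming a hypothetical babbler instance:

def chunks(seqs, batch_size):
    """Yield successive batch_size-sized slices of seqs."""
    for i in range(0, len(seqs), batch_size):
        yield seqs[i:i + batch_size]

# Hypothetical usage: run get_loss_batch once per chunk and concatenate
# losses = np.concatenate(
#     [babbler.get_loss_batch(c, batch_size=len(c)) for c in chunks(seqs, 32)])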
Example #6
    def get_logits(self, seq):
        with tf.Session() as sess:
            initialize_uninitialized(sess)
            int_seq = aa_seq_to_int(seq.strip())[:-1]
            logits = sess.run(self._logits, feed_dict={
                self._batch_size_placeholder: 1,
                self._minibatch_x_placeholder: [int_seq],  # the inputs
                self._initial_state_placeholder: self._zero_state,
            })
        return logits
Example #7
    def get_pad_targets(self, seq):
        with tf.Session() as sess:
            initialize_uninitialized(sess)
            int_seq = aa_seq_to_int(seq.strip())[:-1]
            raw_ts, padded_ts = sess.run(
                [self._minibatch_y_placeholder,
                 self._pad_adjusted_targets],
                feed_dict={
                    self._minibatch_y_placeholder: [int_seq]
                })
            return raw_ts, padded_ts
Example #8
    def get_babble(self, seed, length=250, temp=1):
        """
        Return a babble at temperature temp (on (0,1] with 1 being the noisiest)
        starting with seed and continuing to length length.
        Unfortunately, this method accepts one sequence at a time and is as such quite
        slow.

        """
        with tf.compat.v1.Session() as sess:
            initialize_uninitialized(sess)
            int_seed = aa_seq_to_int(seed.strip())[:-1]

            # No need for padding because this is a single element
            seed_samples, final_state_ = sess.run(
                [self._sample, self._final_state],
                feed_dict={
                    self._minibatch_x_placeholder: [int_seed],
                    self._initial_state_placeholder: self._zero_state,
                    self._batch_size_placeholder: 1,
                    self._temp_placeholder: temp
                }
            )
            # Just the actual character prediction
            pred_int = seed_samples[0, -1] + 1
            seed = seed + int_to_aa[pred_int]

            for i in range(length - len(seed)):
                pred_int, final_state_ = sess.run(
                    [self._sample, self._final_state],
                    feed_dict={
                        self._minibatch_x_placeholder: [[pred_int]],
                        self._initial_state_placeholder: final_state_,
                        self._batch_size_placeholder: 1,
                        self._temp_placeholder: temp
                    }
                )
                pred_int = pred_int[0, 0] + 1
                seed = seed + int_to_aa[pred_int]
        return seed
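For intuition on the temp parameter, assuming the usual logits-divided-by-temperature scaling (the graph's _sample op may differ in detail): temp=1 samples from the model's raw distribution, while smaller temp sharpens it toward the argmax. A minimal numpy sketch:

import numpy as np

def sample_with_temp(logits, temp, rng=np.random.default_rng()):
    """Sample one index from logits sharpened by temperature temp in (0, 1]."""
    scaled = logits / temp             # temp=1 leaves the logits unchanged (noisiest)
    p = np.exp(scaled - scaled.max())  # numerically stable softmax
    p /= p.sum()
    return rng.choice(len(p), p=p)

logits = np.array([2.0, 1.0, 0.5])
print(sample_with_temp(logits, temp=1.0))  # samples fairly broadly
print(sample_with_temp(logits, temp=0.1))  # almost always picks the argmax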
Example #9
    def get_loss(self, seq, losstype='avg'):
        with tf.Session() as sess:
            initialize_uninitialized(sess)
            int_seq = aa_seq_to_int(seq.strip())[:-1]
            if losstype == 'full':
                loss, logits = sess.run(
                    [self._batch_seq_losses, self._logits], feed_dict={
                        self._batch_size_placeholder: 1,
                        self._minibatch_x_placeholder: [int_seq],  # the inputs
                        self._initial_state_placeholder: self._zero_state,
                        self._minibatch_y_placeholder: [int_seq]  # the targets (what we want to reconstruct)
                    })
            elif losstype == 'avg':
                loss, logits = sess.run(
                    [self._loss, self._logits], feed_dict={
                        self._batch_size_placeholder: 1,
                        self._minibatch_x_placeholder: [int_seq],  # the inputs
                        self._initial_state_placeholder: self._zero_state,
                        self._minibatch_y_placeholder: [int_seq]  # the targets (what we want to reconstruct)
                    })
            else:
                raise ValueError("losstype must be 'full' or 'avg'")
        return loss, logits
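The two losstypes differ only in reduction: 'full' returns per-position losses while 'avg' reduces them to a scalar. A toy numpy check of that relationship, assuming the graph's loss is standard softmax cross-entropy averaged over positions:

import numpy as np
import scipy.special

rng = np.random.default_rng(1)
logits = rng.normal(size=(5, 25))      # 5 positions, toy vocab size
targets = rng.integers(0, 25, size=5)

log_probs = logits - scipy.special.logsumexp(logits, axis=1, keepdims=True)
full = -log_probs[np.arange(5), targets]  # per-position losses, shape (5,)
avg = full.mean()                         # scalar, what losstype='avg' corresponds to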
Example #10
    def get_all_hiddens(self, seq_list, sess, return_logits=False):
        """
        Given an amino acid seq list of len <= batch_size, returns a list of 
        hidden state sequences
        """
        int_seq_list = [aa_seq_to_int(s.strip())[:-1] for s in seq_list]

        # Now pad the sequences
        batch = numpy_fillna(int_seq_list)
        nonpad_lens = nonpad_len(batch)
        max_len = batch.shape[1]
        if batch.shape[0] > self._batch_size:
            raise ValueError("The sequence batch is larger than the batch size")
        elif batch.shape[0] < self._batch_size:
            missing = self._batch_size - batch.shape[0]
            mask = np.array(([True] * batch.shape[0]) + ([False] * missing))
            batch = np.concatenate([batch, np.zeros((missing, max_len))],
                                   axis=0)
        elif batch.shape[0] == self._batch_size:
            mask = np.array(([True] * batch.shape[0]))

        if return_logits:
            hiddens, logits = sess.run(
                [self._output, self._logits],
                feed_dict={
                    self._minibatch_x_placeholder: batch,
                    self._initial_state_placeholder: self._zero_state,
                    self._batch_size_placeholder: self._batch_size
                })
        else:
            hiddens = sess.run(self._output,
                               feed_dict={
                                   self._minibatch_x_placeholder: batch,
                                   self._initial_state_placeholder:
                                   self._zero_state,
                                   self._batch_size_placeholder:
                                   self._batch_size
                               })

        assert hiddens.shape[0] == self._batch_size, "Dimension 0 does not match batch size"
        assert hiddens.shape[1] == max_len, "Dimension 1 does not match max_len"
        assert hiddens.shape[2] == self._rnn_size, "Dimension 2 does not match rnn_size"
        hiddens = hiddens[mask, :, :]  # Mask away the zero-padded batch dimension
        result = []
        for i, row in enumerate(hiddens):
            # Row is seq_len x rnn_size
            row = row[:nonpad_lens[i], :]
            assert row.shape[0] == len(seq_list[i]) + 1, \
                "Hidden sequence {0} has the incorrect length".format(i)
            assert row.shape[1] == self._rnn_size, \
                "Hidden state {0} has the wrong dimension".format(i)
            result.append(row)

        if return_logits:
            logits = logits[mask, :, :]
            logit_result = []
            for i, row in enumerate(logits):
                row = row[:nonpad_lens[i], :]
                assert row.shape[0] == len(seq_list[i]) + 1, \
                    "Logits mat {0} has the incorrect length".format(i)
                assert row.shape[1] == self._vocab_size - 1, \
                    "Logits mat {0} has the wrong dimension".format(i)
                logit_result.append(row)

        if return_logits:
            return result, logit_result
        else:
            return result
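numpy_fillna and nonpad_len are helpers from the same codebase; below is a plausible minimal reimplementation (pad ragged int lists with zeros on the right, then count non-zero tokens per row) for readers who want to trace the shapes. The _sketch suffix marks these as assumptions, not the actual library functions:

import numpy as np

def numpy_fillna_sketch(int_seq_list):
    """Right-pad a list of int lists with zeros into a rectangular array."""
    max_len = max(len(s) for s in int_seq_list)
    out = np.zeros((len(int_seq_list), max_len), dtype=int)
    for i, s in enumerate(int_seq_list):
        out[i, :len(s)] = s
    return out

def nonpad_len_sketch(batch):
    """Number of non-pad (non-zero) tokens in each row."""
    return (batch != 0).sum(axis=1)

batch = numpy_fillna_sketch([[1, 2, 3], [4, 5]])
print(batch)                     # [[1 2 3]
                                 #  [4 5 0]]
print(nonpad_len_sketch(batch))  # [3 2]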