コード例 #1
0
ファイル: decoder.py プロジェクト: zyzhang1992/tfGENTRL
    def weighted_forward(self, sm_list, z):
        """Return the (optionally token-weighted) mean log-probability per SMILES.

        TensorFlow port of the PyTorch ``weighted_forward``.

        Args:
            sm_list: TensorFlow tensor of SMILES strings; it is converted with
                ``.numpy().astype(str)`` before being tokenized by ``encode``.
            z: latent codes, passed unchanged to ``self.get_log_prob``.

        Returns:
            One value per input sequence: the sum of ``self.get_log_prob``'s
            per-step values (each multiplied by the weight of the token at
            ``x[:, j + 1]`` when ``self.token_weights`` is set), over steps
            whose current token ``x[:, j]`` is not ``self.eos``, divided by the
            number of such steps.
        """
        x = encode(sm_list.numpy().astype(str))[0]

        seq_logits = self.get_log_prob(x, z)

        if self.token_weights is not None:
            # Flattened ids of x[:, 1:]. tf.gather is required: a tf.Tensor or
            # tf.Variable cannot be indexed with a 1-D tensor of ids via []
            # (TypeError), unlike torch/NumPy fancy indexing.
            target_ids = tf.reshape(tf.cast(x[:, 1:], tf.int64), [-1])
            w = tf.gather(self.token_weights, target_ids)
            # Dynamic tf.shape so this also works when seq_logits has
            # dimensions that are unknown at trace time.
            w = tf.reshape(w, tf.shape(seq_logits))
            seq_logits = seq_logits * w

        # Step j of seq_logits lines up with x[:, j + 1] (see the weighting
        # above); the mask uses x[:, j], dropping steps whose current token
        # is EOS. Cast to seq_logits' dtype so the product never mismatches.
        non_eof = tf.cast(tf.not_equal(x, self.eos)[:, :-1], seq_logits.dtype)
        ans_logits = tf.reduce_sum(seq_logits * non_eof, axis=-1)
        ans_logits /= tf.reduce_sum(non_eof, axis=-1)

        return ans_logits
コード例 #2
0
    def encode(self, sm_list):
        """
        Project a batch of SMILES strings into the latent space.

        The strings are tokenized by the module-level ``encode`` helper using
        ``self.vocab``, embedded with ``self.embs`` and run through
        ``self.rnn``. For each sequence ``i`` the RNN output at index
        ``[lens[i], i]`` is selected and passed through ``self.final_mlp``.
        """
        token_ids, seq_lens = encode(sm_list, self.vocab)
        # Swap the first two axes (presumably batch-major -> time-major, which
        # matches the [time, batch] indexing below) and move to the embedding
        # weights' device.
        time_major = token_ids.transpose(1, 0).to(self.embs.weight.device)

        rnn_out = self.rnn(self.embs(time_major))[0]
        batch_positions = torch.arange(len(seq_lens))
        selected = rnn_out[seq_lens, batch_positions]

        return self.final_mlp(selected)
コード例 #3
0
ファイル: decoder.py プロジェクト: samhitha00/GENTRLS
    def weighted_forward(self, sm_list, z):
        '''
        Score each SMILES string under the latent code(s) ``z``.

        Returns one value per sequence: the sum of the per-step values from
        ``self.get_log_prob`` (each multiplied by the weight of token
        ``x[:, j + 1]`` when ``self.token_weights`` is set), taken over steps
        whose current token ``x[:, j]`` is not ``self.eos``, divided by the
        number of such steps.
        '''
        raw_tokens = encode(sm_list)[0]
        tokens = raw_tokens.to(self.input_embeddings.weight.data.device)

        step_scores = self.get_log_prob(tokens, z)

        weights = self.token_weights
        if weights is not None:
            # Flattened ids of every token after the first; reshaped back to
            # the layout of step_scores.
            next_ids = tokens[:, 1:].long().contiguous().view(-1)
            step_scores = step_scores * weights[next_ids].view_as(step_scores)

        # Keep step j only while the current token x[:, j] is not EOS.
        keep = (tokens != self.eos)[:, :-1].float()
        totals = (step_scores * keep).sum(dim=-1)
        return totals / keep.sum(dim=-1)
コード例 #4
0
ファイル: encoder.py プロジェクト: zyzhang1992/tfGENTRL
    def encode(self, sm_list):
        """
        Maps smiles onto a latent space.

        TensorFlow port of the PyTorch encoder: the SMILES are tokenized by
        the module-level ``encode`` helper, embedded with ``self.embs`` and
        run through ``self.rnn``; for each sequence ``i`` the RNN output at
        position ``lens[i]`` is selected and passed through
        ``self.final_mlp``.

        Args:
            sm_list: TensorFlow tensor of SMILES strings; converted with
                ``.numpy().astype(str)`` before tokenization.

        Returns:
            Output of ``self.final_mlp`` for the selected RNN outputs, one
            row per input SMILES.
        """
        tokens, lens = encode(sm_list.numpy().astype(str))

        # NOTE(review): assumes self.rnn returns batch-major outputs
        # (batch, time, features), e.g. a Keras RNN with
        # return_sequences=True — confirm against the layer definition.
        outputs = self.rnn(self.embs(tokens))

        # One [batch_index, time_index] pair per sequence: pick the output at
        # position lens[i] of sequence i.
        idx = [[i, length] for i, length in enumerate(lens)]
        outputs = tf.gather_nd(outputs, idx)

        return self.final_mlp(outputs)