def weighted_forward(self, sm_list, z):
    """Return a per-sequence, masked mean of (optionally weighted) token scores.

    Args:
        sm_list: Batch of SMILES strings. ``sm_list.numpy()`` must exist
            (presumably an eager ``tf.Tensor`` of strings -- confirm with callers).
        z: Passed unchanged to ``self.get_log_prob``.

    Returns:
        One value per sequence: the sum over the last axis of the values from
        ``self.get_log_prob`` (multiplied by per-token weights when
        ``self.token_weights`` is set), restricted to positions whose input
        token in ``x[:, :-1]`` is not ``self.eos``, divided by the number of
        such positions.
    """
    # Module-level tokenizer; element 0 of its result is the token-id batch.
    x = encode(sm_list.numpy().astype(str))[0]
    seq_logits = self.get_log_prob(x, z)

    if self.token_weights is not None:
        # Weight each position by the weight of the token at x[:, 1:]
        # (presumably the next-token target -- confirm against get_log_prob).
        # tf.gather accepts either a tf tensor or a NumPy array as the table;
        # tf.Tensor.__getitem__ rejects rank-1 tensor indices.
        target_ids = tf.reshape(tf.cast(x[:, 1:], tf.int64), [-1])
        w = tf.gather(self.token_weights, target_ids)
        # Dynamic shape so this also works when the batch size is unknown
        # at trace time. Cast because TF does not auto-promote dtypes in `*`.
        w = tf.cast(tf.reshape(w, tf.shape(seq_logits)), seq_logits.dtype)
        seq_logits = seq_logits * w

    # 1.0 where the input token (all but the last column) is not EOS, else 0.0.
    non_eof = tf.cast((x != self.eos)[:, :-1], tf.float32)
    ans_logits = tf.reduce_sum(seq_logits * non_eof, axis=-1)
    ans_logits /= tf.reduce_sum(non_eof, axis=-1)
    return ans_logits
def encode(self, sm_list):
    """Map a batch of SMILES strings to latent vectors.

    The strings are tokenised by the module-level ``encode`` helper using
    ``self.vocab``. The token ids are embedded and run through ``self.rnn``.
    For each sequence ``i``, the RNN output at step ``lens[i]`` is then
    projected by ``self.final_mlp``.
    """
    token_ids, lengths = encode(sm_list, self.vocab)
    # Swap the first two axes. Given the [lengths, batch_index] indexing
    # below, axis 0 becomes the step axis and axis 1 the batch axis.
    step_first = token_ids.transpose(1, 0).to(self.embs.weight.device)
    embedded = self.embs(step_first)
    rnn_out = self.rnn(embedded)[0]
    batch_index = torch.arange(len(lengths))
    selected = rnn_out[lengths, batch_index]
    return self.final_mlp(selected)
def weighted_forward(self, sm_list, z):
    """Return a per-sequence, masked mean of (optionally weighted) token scores.

    ``sm_list`` is tokenised by the module-level ``encode`` helper and moved
    to the device of ``self.input_embeddings``. The values from
    ``self.get_log_prob(tokens, z)`` are optionally multiplied by
    ``self.token_weights`` looked up at ``tokens[:, 1:]``. They are then
    summed over the last axis at positions whose input token
    (``tokens[:, :-1]``) is not ``self.eos``, and divided by the count of
    those positions.
    """
    tokens = encode(sm_list)[0]
    tokens = tokens.to(self.input_embeddings.weight.data.device)
    scores = self.get_log_prob(tokens, z)

    if self.token_weights is not None:
        # Flatten the shifted token ids, look up their weights, and restore
        # the score tensor's shape before scaling.
        target_ids = tokens[:, 1:].long().reshape(-1)
        scores = scores * self.token_weights[target_ids].view_as(scores)

    keep = (tokens != self.eos)[:, :-1].float()
    summed = (scores * keep).sum(dim=-1)
    return summed / keep.sum(dim=-1)
def encode(self, sm_list):
    """Map a batch of SMILES strings onto the latent space.

    Args:
        sm_list: Batch of SMILES. ``sm_list.numpy()`` must exist (presumably
            an eager ``tf.Tensor`` of strings -- confirm with callers).

    Returns:
        ``self.final_mlp`` applied to one RNN output row per sequence. For
        sequence ``b`` this is the output at step ``lens[b]``, where
        ``lens`` is the second value returned by the module-level ``encode``
        helper.
    """
    tokens, lens = encode(sm_list.numpy().astype(str))
    outputs = self.rnn(self.embs(tokens))
    # gather_nd below indexes outputs as [b, lens[b]], so axis 0 is the batch
    # and axis 1 the step axis. Whether lens[b] points at the last real token
    # or one past it depends on the tokenizer -- NOTE(review): confirm.
    # Build all (b, lens[b]) pairs in one vectorised op; int64 on both
    # columns avoids mixing int32/int64 index dtypes.
    batch_ids = tf.range(len(lens), dtype=tf.int64)
    idx = tf.stack([batch_ids, tf.cast(lens, tf.int64)], axis=1)
    outputs = tf.gather_nd(outputs, idx)
    return self.final_mlp(outputs)