Example #1
    def forward(self, state):
        input, message, f_atoms, f_bonds, a2a, a2b, b2a, b2revb, a_scope = state

        a2x = a2a if self.atom_messages else a2b
        nei_a_message = index_select_ND(
            message, a2x)  # num_atoms x max_num_bonds x hidden
        a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
        a_input = T.cat([f_atoms, a_message],
                        dim=1)  # num_atoms x (atom_fdim + hidden)
        atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
        atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden

        # Readout
        mol_vecs = []
        for i, (a_start, a_size) in enumerate(a_scope):
            if a_size == 0:
                mol_vecs.append(self.cached_zero_vector)
            else:
                cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)
                mol_vec = cur_hiddens  # (num_atoms, hidden_size)

                mol_vec = mol_vec.sum(dim=0) / a_size
                # equivalently: mol_vec = mol_vec.mean(dim=0)
                mol_vecs.append(mol_vec)

        mol_vecs = T.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)

        return mol_vecs
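All of the encoders on this page lean on an `index_select_ND` helper that gathers neighbor messages by index; it is not shown in the snippets. A minimal sketch, consistent with how it is usually written in chemprop-style code (the padding-row convention for index 0 is an assumption about the featurization):

import torch


def index_select_ND(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """Gather rows of `source` for every entry of `index`.

    source: (num_items, hidden) message or feature matrix.
    index:  (num_items, max_num_neighbors) integer neighbor map (e.g. a2b, a2a),
            where 0 conventionally points at a reserved padding row.
    Returns a (num_items, max_num_neighbors, hidden) tensor whose rows feed the
    `.sum(dim=1)` calls in the encoders above and below.
    """
    index_size = index.size()                # (num_items, max_num_neighbors)
    suffix_dim = source.size()[1:]           # (hidden,)
    final_size = index_size + suffix_dim     # target shape
    target = source.index_select(dim=0, index=index.view(-1))  # flat gather
    return target.view(final_size)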
Example #2
    def forward(self,mol_graph: BatchMolGraph, features_batch=None) -> torch.FloatTensor:

        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, bonds = mol_graph.get_components()
        if self.args.cuda or next(self.parameters()).is_cuda:
            f_atoms, f_bonds, a2b, b2a, b2revb = (
                    f_atoms.cuda(), f_bonds.cuda(), 
                    a2b.cuda(), b2a.cuda(), b2revb.cuda())
            
        # Input
        input_atom = self.W_i_atom(f_atoms)  # num_atoms x hidden_size
        input_atom = self.act_func(input_atom)
        message_atom = input_atom.clone()
        
        input_bond = self.W_i_bond(f_bonds)  # num_bonds x hidden_size
        message_bond = self.act_func(input_bond)
        input_bond = self.act_func(input_bond)  # keep the activated bond input for the residual connection below
        # Message passing
        for depth in range(self.depth - 1):
            agg_message = index_select_ND(message_bond, a2b)
            agg_message = agg_message.sum(dim=1) * agg_message.max(dim=1)[0]
            message_atom = message_atom + agg_message
            
            # directed graph
            rev_message = message_bond[b2revb]  # num_bonds x hidden
            message_bond = message_atom[b2a] - rev_message  # num_bonds x hidden
            
            message_bond = self._modules[f'W_h_{depth}'](message_bond)
            message_bond = self.dropout_layer(self.act_func(input_bond + message_bond))
        
        agg_message = index_select_ND(message_bond, a2b)
        agg_message = agg_message.sum(dim=1) * agg_message.max(dim=1)[0]
        agg_message = self.lr(torch.cat([agg_message, message_atom, input_atom], 1))
        agg_message = self.gru(agg_message, a_scope)
        
        atom_hiddens = self.act_func(self.W_o(agg_message))  # num_atoms x hidden
        atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden
        
        # Readout
        mol_vecs = []
        for i, (a_start, a_size) in enumerate(a_scope):
            if a_size == 0:
                assert 0
            cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)
            mol_vecs.append(cur_hiddens.mean(0))
        mol_vecs = torch.stack(mol_vecs, dim=0)
        
        return mol_vecs  # B x H
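Example #2 looks like a CMPNN-style communicative encoder: at every depth the incoming bond messages are pooled with an elementwise sum-times-max booster before being added to the atom state, and `self.gru` (a batch GRU over `a_scope`) refines the final aggregate. A tiny, self-contained sketch of the booster with made-up shapes:

import torch

# Hypothetical stand-in for index_select_ND(message_bond, a2b):
# num_atoms x max_num_bonds x hidden.
nei_message = torch.randn(4, 3, 8)
agg_message = nei_message.sum(dim=1) * nei_message.max(dim=1)[0]  # num_atoms x hidden
print(agg_message.shape)  # torch.Size([4, 8])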
Example #3
    def forward(self, fnode: torch.Tensor, fmess: torch.Tensor,
                node_graph: torch.Tensor, mess_graph: torch.Tensor,
                scope: List[Tuple[int, int]]) -> torch.Tensor:
        messages = torch.zeros(mess_graph.size(0), self.hidden_size)

        if next(self.parameters()).is_cuda:
            fnode, fmess, node_graph, mess_graph, messages = fnode.cuda(
            ), fmess.cuda(), node_graph.cuda(), mess_graph.cuda(
            ), messages.cuda()

        fnode = self.embedding(fnode)
        fmess = index_select_ND(fnode, fmess)
        messages = self.GRU(messages, fmess, mess_graph)

        mess_nei = index_select_ND(messages, node_graph)
        fnode = torch.cat([fnode, mess_nei.sum(dim=1)], dim=-1)
        fnode = self.outputNN(fnode)
        tree_vec = []
        for st, le in scope:
            tree_vec.append(fnode.narrow(0, st, le).mean(dim=0))

        return torch.stack(tree_vec, dim=0)
Example #4
    def forward(self, state, context):
        input, message, f_atoms, f_bonds, a2a, a2b, b2a, b2revb, a_scope = state

        expanded_context = _build_expanded_context(context, a_scope)
        if not self.atom_messages:
            expanded_context = expanded_context[b2a]

        # Message passing
        for depth in range(self.depth):
            if self.undirected:
                message = (message + message[b2revb]) / 2

            if self.atom_messages:
                nei_a_message = index_select_ND(
                    message, a2a)  # num_atoms x max_num_bonds x hidden
                nei_f_bonds = index_select_ND(
                    f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
                nei_message = T.cat(
                    (nei_a_message, nei_f_bonds),
                    dim=2)  # num_atoms x max_num_bonds x (hidden + bond_fdim)
                message = nei_message.sum(
                    dim=1)  # num_atoms x hidden + bond_fdim
            else:
                # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
                # message      a_message = sum(nei_a_message)      rev_message
                nei_a_message = index_select_ND(
                    message, a2b)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                rev_message = message[b2revb]  # num_bonds x hidden
                message = a_message[b2a] - rev_message  # num_bonds x hidden

            message = self.W_h(T.cat((message, expanded_context), -1))
            message = self.act_func(input + message)  # num_bonds x hidden_size
            message = self.dropout_layer(message)  # num_bonds x hidden

        return (input, message, f_atoms, f_bonds, a2a, a2b, b2a, b2revb,
                a_scope)
Example #5
    def forward(self, smiles_batch: List[str]):
        # Get MolTrees with memoization
        mol_batch = [
            SMILES_TO_MOLTREE[smiles] if smiles in SMILES_TO_MOLTREE else
            SMILES_TO_MOLTREE.setdefault(smiles, MolTree(smiles))
            for smiles in smiles_batch
        ]
        fnode, fmess, node_graph, mess_graph, scope = self.tensorize(mol_batch)

        if next(self.parameters()).is_cuda:
            fnode, fmess, node_graph, mess_graph = fnode.cuda(), fmess.cuda(
            ), node_graph.cuda(), mess_graph.cuda()

        fnode = self.embedding(fnode)
        fmess = index_select_ND(fnode, fmess)
        tree_vec = self.jtnn((fnode, fmess, node_graph, mess_graph, scope, []))
        mol_vec = self.mpn(smiles_batch)

        return torch.cat([tree_vec, mol_vec], dim=-1)
Example #6
    def forward(self,
                mol_graph: BatchMolGraph,
                features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
        """
        Encodes a batch of molecular graphs.

        :param mol_graph: A BatchMolGraph representing a batch of molecular graphs.
        :param features_batch: A list of ndarrays containing additional features.
        :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
        """
        if self.use_input_features:
            features_batch = torch.from_numpy(
                np.stack(features_batch)).float().to(self.device)

            if self.features_only:
                return features_batch

        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components(
            atom_messages=self.atom_messages)
        f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.to(
            self.device), f_bonds.to(self.device), a2b.to(self.device), b2a.to(
                self.device), b2revb.to(self.device)

        if self.atom_messages:
            a2a = mol_graph.get_a2a().to(self.device)

        # Input
        if self.atom_messages:
            input = self.W_i(f_atoms)  # num_atoms x hidden_size
        else:
            input = self.W_i(f_bonds)  # num_bonds x hidden_size
        message = self.act_func(input)  # num_bonds x hidden_size

        # Message passing
        for depth in range(self.depth - 1):
            if self.undirected:
                message = (message + message[b2revb]) / 2

            if self.atom_messages:
                nei_a_message = index_select_ND(
                    message, a2a)  # num_atoms x max_num_bonds x hidden
                nei_f_bonds = index_select_ND(
                    f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
                nei_message = torch.cat(
                    (nei_a_message, nei_f_bonds),
                    dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim
                message = nei_message.sum(
                    dim=1)  # num_atoms x hidden + bond_fdim
            else:
                # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
                # message      a_message = sum(nei_a_message)      rev_message
                nei_a_message = index_select_ND(
                    message, a2b)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                rev_message = message[b2revb]  # num_bonds x hidden
                message = a_message[b2a] - rev_message  # num_bonds x hidden

            message = self.W_h(message)
            message = self.act_func(input + message)  # num_bonds x hidden_size
            message = self.dropout_layer(message)  # num_bonds x hidden

        a2x = a2a if self.atom_messages else a2b
        nei_a_message = index_select_ND(
            message, a2x)  # num_atoms x max_num_bonds x hidden
        a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
        a_input = torch.cat([f_atoms, a_message],
                            dim=1)  # num_atoms x (atom_fdim + hidden)
        atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
        atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden

        # Readout
        mol_vecs = []
        for i, (a_start, a_size) in enumerate(a_scope):
            if a_size == 0:
                mol_vecs.append(self.cached_zero_vector)
            else:
                cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)
                mol_vec = cur_hiddens  # (num_atoms, hidden_size)

                mol_vec = mol_vec.sum(dim=0) / a_size
                mol_vecs.append(mol_vec)

        mol_vecs = torch.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)

        if self.use_input_features:
            features_batch = features_batch.to(mol_vecs)
            if len(features_batch.shape) == 1:
                features_batch = features_batch.view(
                    [1, features_batch.shape[0]])
            mol_vecs = torch.cat([mol_vecs, features_batch],
                                 dim=1)  # (num_molecules, hidden_size)

        return mol_vecs  # num_molecules x hidden
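The readout above (shared by most examples here) mean-pools each molecule's rows of `atom_hiddens` by slicing with `narrow` over `a_scope`. The loop is equivalent to a single vectorized `index_add_`; a toy sketch with made-up shapes and scopes (real chemprop scopes start at index 1 because of a padding atom):

import torch

# Toy batch: 2 molecules with 3 and 2 atoms, hidden size 4 (hypothetical values).
atom_hiddens = torch.randn(5, 4)
a_scope = [(0, 3), (3, 2)]  # (a_start, a_size) per molecule

# Loop readout, as written in the examples above.
loop_vecs = torch.stack(
    [atom_hiddens.narrow(0, start, size).sum(dim=0) / size for start, size in a_scope],
    dim=0,
)

# Equivalent vectorized mean pooling.
sizes = torch.tensor([size for _, size in a_scope])
mol_index = torch.repeat_interleave(torch.arange(len(a_scope)), sizes)
sums = torch.zeros(len(a_scope), atom_hiddens.size(1)).index_add_(0, mol_index, atom_hiddens)
vec_vecs = sums / sizes.unsqueeze(1)

assert torch.allclose(loop_vecs, vec_vecs)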
Example #7
    def forward(
            self,
            mol_graph: BatchMolGraph,
            atom_descriptors_batch: List[np.ndarray] = None
    ) -> torch.FloatTensor:
        """
        Encodes a batch of molecular graphs.

        :param mol_graph: A :class:`~chemprop.features.featurization.BatchMolGraph` representing
                          a batch of molecular graphs.
        :param atom_descriptors_batch: A list of numpy arrays containing additional atomic descriptors
        :return: A PyTorch tensor of shape :code:`(num_molecules, hidden_size)` containing the encoding of each molecule.
        """
        if atom_descriptors_batch is not None:
            atom_descriptors_batch = [
                np.zeros([1, atom_descriptors_batch[0].shape[1]])
            ] + atom_descriptors_batch  # padding the first with 0 to match the atom_hiddens
            atom_descriptors_batch = (torch.from_numpy(
                np.concatenate(atom_descriptors_batch,
                               axis=0)).float().to(self.device))

        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components(
            atom_messages=self.atom_messages)
        f_atoms, f_bonds, a2b, b2a, b2revb = (
            f_atoms.to(self.device),
            f_bonds.to(self.device),
            a2b.to(self.device),
            b2a.to(self.device),
            b2revb.to(self.device),
        )

        if self.atom_messages:
            a2a = mol_graph.get_a2a().to(self.device)

        # Input
        if self.atom_messages:
            input = self.W_i(f_atoms)  # num_atoms x hidden_size
        else:
            input = self.W_i(f_bonds)  # num_bonds x hidden_size
        message = self.act_func(input)  # num_bonds x hidden_size

        # Message passing
        for depth in range(self.depth - 1):
            if self.undirected:
                message = (message + message[b2revb]) / 2

            if self.atom_messages:
                nei_a_message = index_select_ND(
                    message, a2a)  # num_atoms x max_num_bonds x hidden
                nei_f_bonds = index_select_ND(
                    f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
                nei_message = torch.cat(
                    (nei_a_message, nei_f_bonds),
                    dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim
                message = nei_message.sum(
                    dim=1)  # num_atoms x hidden + bond_fdim
            else:
                # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
                # message      a_message = sum(nei_a_message)      rev_message
                nei_a_message = index_select_ND(
                    message, a2b)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                rev_message = message[b2revb]  # num_bonds x hidden
                message = a_message[b2a] - rev_message  # num_bonds x hidden

            message = self.W_h(message)
            message = self.act_func(input + message)  # num_bonds x hidden_size
            message = self.dropout_layer(message)  # num_bonds x hidden

        a2x = a2a if self.atom_messages else a2b
        nei_a_message = index_select_ND(
            message, a2x)  # num_atoms x max_num_bonds x hidden
        a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
        a_input = torch.cat([f_atoms, a_message],
                            dim=1)  # num_atoms x (atom_fdim + hidden)
        atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
        atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden

        # concatenate the atom descriptors
        if atom_descriptors_batch is not None:
            if len(atom_hiddens) != len(atom_descriptors_batch):
                raise ValueError(
                    f"The number of atoms is different from the length of the extra atom features"
                )

            atom_hiddens = torch.cat(
                [atom_hiddens, atom_descriptors_batch],
                dim=1)  # num_atoms x (hidden + descriptor size)
            atom_hiddens = self.atom_descriptors_layer(
                atom_hiddens)  # num_atoms x (hidden + descriptor size)
            atom_hiddens = self.dropout_layer(
                atom_hiddens)  # num_atoms x (hidden + descriptor size)

        # Readout
        mol_vecs = []
        for i, (a_start, a_size) in enumerate(a_scope):
            if a_size == 0:
                mol_vecs.append(self.cached_zero_vector)
            else:
                cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)
                mol_vec = cur_hiddens  # (num_atoms, hidden_size)
                if self.aggregation == "mean":
                    mol_vec = mol_vec.sum(dim=0) / a_size
                elif self.aggregation == "sum":
                    mol_vec = mol_vec.sum(dim=0)
                elif self.aggregation == "norm":
                    mol_vec = mol_vec.sum(dim=0) / self.aggregation_norm
                mol_vecs.append(mol_vec)

        mol_vecs = torch.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)

        return mol_vecs  # num_molecules x hidden
Example #8
    def forward(self,
                mol_graph: BatchMolGraph,
                features_batch: List[np.ndarray] = None,
                sample = False) -> torch.FloatTensor:
        """
        Encodes a batch of molecular graphs.

        :param mol_graph: A BatchMolGraph representing a batch of molecular graphs.
        :param features_batch: A list of ndarrays containing additional features.
        :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
        """

        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components(atom_messages=self.atom_messages)
        f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.to(self.device), f_bonds.to(self.device), a2b.to(self.device), b2a.to(self.device), b2revb.to(self.device)
        f_atoms_or_bonds = f_atoms if self.atom_messages else f_bonds
        ##### LAYER FOR HIDDEN STATE INITIALISATION #####
        input, kl = self.W_i(f_atoms_or_bonds, sample)
        tkl = kl
        message = self.act_func(input)  # num_bonds x hidden_size
        #################################################
        mol_vecs_list = []        
        # Message passing
        for depth in range(self.depth_max):
            
            if depth != 0:
                nei_a_message = index_select_ND(message, a2b)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                rev_message = message[b2revb]  # num_bonds x hidden
                message = a_message[b2a] - rev_message  # num_bonds x hidden
                
                ##### LAYER FOR HIDDEN STATE UPDATES #####
                message, kl = self.W_h(message, sample)
                if depth == self.depth_max - 1:
                    tkl += kl # ONLY ADD ON KL LOSS ONCE
                
                message = self.act_func(input + message)  # num_bonds x hidden_size
                message = self.dropout_layer(message)  # num_bonds x hidden
                ##########################################
            # save readout outputs for the final depths (depth >= self.depth_min - 1)
            if depth >= self.depth_min - 1:
            
                nei_a_message = index_select_ND(message, a2b)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                a_input = torch.cat([f_atoms, a_message], dim=1)  # num_atoms x (atom_fdim + hidden)
                
                ##### LAYER FOR ATOM REPRESENTATION #####
                atom_hiddens, kl = self.W_o(a_input, sample)
                if depth == self.depth_max - 1:
                    tkl += kl # ONLY ADD ON KL LOSS ONCE       
                atom_hiddens = self.act_func(atom_hiddens)  # num_atoms x hidden
                atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden
                #########################################
                
                # Readout
                mol_vecs = []
                for i, (a_start, a_size) in enumerate(a_scope):
                    if a_size == 0:
                        mol_vecs.append(self.cached_zero_vector)
                    else:
                        cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)
                        mol_vec = cur_hiddens  # (num_atoms, hidden_size)
                        mol_vec = mol_vec.sum(dim=0) / a_size
                        mol_vecs.append(mol_vec)
                mol_vecs = torch.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)
                mol_vecs_list.append(mol_vecs)
        return mol_vecs_list, tkl
Example #9
    def forward(self,
                mol_graph: BatchMolGraph,
                features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
        """
        Encodes a batch of molecular graphs.

        :param mol_graph: A BatchMolGraph representing a batch of molecular graphs.
        :param features_batch: A list of ndarrays containing additional features.
        :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
        """
        if self.use_input_features:
            features_batch = torch.from_numpy(np.stack(features_batch)).float()

            if self.args.cuda:
                features_batch = features_batch.cuda()

            if self.features_only:
                return features_batch

        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, smiles_batch = mol_graph.get_components(
        )
        print('f_atoms', f_atoms.shape)
        print(f_atoms[0, :])
        print(f_atoms[1, :])
        print('f_bonds', f_bonds.shape)
        print(f_bonds[0, :])
        print('a2b', a2b)
        print('b2a', b2a)
        print('b2revb', b2revb)
        print('a_scope', a_scope)
        print('b_scope', b_scope)

        if self.atom_messages:
            a2a = mol_graph.get_a2a()

        if self.args.cuda or next(self.parameters()).is_cuda:
            f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda(
            ), a2b.cuda(), b2a.cuda(), b2revb.cuda()

            if self.atom_messages:
                a2a = a2a.cuda()

        # Input
        if self.atom_messages:
            input = self.W_i(f_atoms)  # num_atoms x hidden_size
        else:
            input = self.W_i(f_bonds)  # num_bonds x hidden_size
        print('smiles_batch', smiles_batch)
        print('f_bonds', f_bonds.shape)
        print('\n\n\n\n')
        # print('input', input.shape)
        message = self.act_func(input)  # num_bonds x hidden_size
        # print('before', message[0, :2])
        # message[0, 0] = message[0, 0] + 0.001
        # print('after', message[0, :2])
        # print('message', message, message.shape)
        # print('\n\n')

        # Message passing
        for depth in range(self.depth - 1):  # GCNN
            # print('depth', depth)
            if self.undirected:
                message = (message + message[b2revb]) / 2

            if self.atom_messages:
                nei_a_message = index_select_ND(
                    message, a2a)  # num_atoms x max_num_bonds x hidden
                nei_f_bonds = index_select_ND(
                    f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
                nei_message = torch.cat(
                    (nei_a_message, nei_f_bonds),
                    dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim
                message = nei_message.sum(
                    dim=1)  # num_atoms x hidden + bond_fdim
            else:
                # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
                # message      a_message = sum(nei_a_message)      rev_message
                nei_a_message = index_select_ND(
                    message, a2b)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                rev_message = message[b2revb]  # num_bonds x hidden
                message = a_message[b2a] - rev_message  # num_bonds x hidden

            message = self.W_h(message)
            message = self.act_func(input + message)  # num_bonds x hidden_size
            message = self.dropout_layer(message)  # num_bonds x hidden

        a2x = a2a if self.atom_messages else a2b
        nei_a_message = index_select_ND(
            message, a2x)  # num_atoms x max_num_bonds x hidden
        a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
        a_input = torch.cat(
            [f_atoms, a_message],
            dim=1)  # num_atoms x (atom_fdim + hidden)  133 + 1000
        atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
        atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden

        # Readout
        mol_vecs = []
        for i, (a_start, a_size) in enumerate(a_scope):
            if a_size == 0:
                mol_vecs.append(self.cached_zero_vector)
        #        MPNEncoder.output_x.append(self.cached_zero_vector)
            else:
                cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)
                mol_vec = cur_hiddens  # (num_atoms, hidden_size)

                mol_vec = mol_vec.sum(dim=0) / a_size  # sum of all atom vectors in the molecule / number of atoms
                mol_vecs.append(mol_vec)
        #        MPNEncoder.output_x.append(mol_vec)

#####################################################
#if len(mol_vecs) < 50:
#        nnnnvecs = torch.stack(MPNEncoder.output_x, dim=0)
#        with open('hidden_vector.pkl', 'wb') as f:
#                pickle.dump(nnnnvecs, f)
#                print('mol_vecs: ', nnnnvecs)
#                print('mol_vecs_size: ', nnnnvecs.shape)
#####################################################

        mol_vecs = torch.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)

        if self.use_input_features:
            features_batch = features_batch.to(mol_vecs)
            if len(features_batch.shape) == 1:
                features_batch = features_batch.view(
                    [1, features_batch.shape[0]])
            mol_vecs = torch.cat([mol_vecs, features_batch],
                                 dim=1)  # (num_molecules, hidden_size)

        return mol_vecs  # num_molecules x hidden
Example #10
    def forward(self,
                mol_graph: BatchMolGraph,
                features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
        """
        Encodes a batch of molecular graphs.

        :param mol_graph: A BatchMolGraph representing a batch of molecular graphs.
        :param features_batch: A list of ndarrays containing additional features.
        :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
        """
        if self.use_input_features:
            features_batch = torch.from_numpy(np.stack(features_batch)).float()

            if self.args.cuda:
                features_batch = features_batch.cuda()

            if self.features_only:
                return features_batch

        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components(
        )

        if self.atom_messages:
            a2a = mol_graph.get_a2a()

        if self.args.cuda or next(self.parameters()).is_cuda:
            f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda(
            ), a2b.cuda(), b2a.cuda(), b2revb.cuda()

            if self.atom_messages:
                a2a = a2a.cuda()

        # Input
        if self.atom_messages:
            #print('atom_messages')
            input = self.W_i(f_atoms)  # num_atoms x hidden_size
        else:
            input = self.W_i(f_bonds)  # num_bonds x hidden_size

        message = self.act_func(input)  # num_bonds x hidden_size

        save_depth = []  # wei, update for each step

        # wei, save the information of depth=0 (atomic features)
        padding = torch.zeros(
            (f_atoms.size()[0], f_atoms.size()[1] + self.hidden_size))
        if self.args.cuda:
            padding = padding.cuda()
        padding[:, :f_atoms.size()[1]] = f_atoms
        a_input = padding
        atom_hiddens = self.act_func(self.W_o(a_input))
        atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden
        save_depth.append(atom_hiddens)
        #print('mol_graph:', mol_graph.smiles_batch)

        # Message passing
        for depth in range(self.depth - 1):

            ################ save information of one bond distance from the central atom ################
            if depth == 0:
                a2x = a2a if self.atom_messages else a2b
                nei_a_message = index_select_ND(
                    message, a2x)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                a_input = torch.cat([f_atoms, a_message],
                                    dim=1)  # num_atoms x (atom_fdim + hidden)
                atom_hiddens = self.act_func(
                    self.W_o(a_input))  # num_atoms x hidden
                atom_hiddens = self.dropout_layer(
                    atom_hiddens)  # num_atoms x hidden
                save_depth.append(atom_hiddens)
            ################ save information of one bond distance from the central atom ################

            if self.undirected:
                message = (message + message[b2revb]) / 2

            if self.atom_messages:
                nei_a_message = index_select_ND(
                    message, a2a)  # num_atoms x max_num_bonds x hidden
                nei_f_bonds = index_select_ND(
                    f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
                nei_message = torch.cat(
                    (nei_a_message, nei_f_bonds),
                    dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim
                message = nei_message.sum(
                    dim=1)  # num_atoms x hidden + bond_fdim
            else:
                # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
                # message      a_message = sum(nei_a_message)      rev_message
                nei_a_message = index_select_ND(
                    message, a2b)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                rev_message = message[b2revb]  # num_bonds x hidden
                message = a_message[b2a] - rev_message  # num_bonds x hidden

            message = self.W_h(message)
            message = self.act_func(input + message)  # num_bonds x hidden_size
            message = self.dropout_layer(message)  # num_bonds x hidden

            # wei, save depth
            a2x = a2a if self.atom_messages else a2b
            nei_a_message = index_select_ND(
                message, a2x)  # num_atoms x max_num_bonds x hidden
            a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
            a_input = torch.cat([f_atoms, a_message],
                                dim=1)  # num_atoms x (atom_fdim + hidden)
            atom_hiddens = self.act_func(
                self.W_o(a_input))  # num_atoms x hidden
            atom_hiddens = self.dropout_layer(
                atom_hiddens)  # num_atoms x hidden
            save_depth.append(atom_hiddens)  # save information of each depth

        ############ origin ############
        #print('message_after_m_passing:\n', message)
        #a2x = a2a if self.atom_messages else a2b
        #nei_a_message = index_select_ND(message, a2x)  # num_atoms x max_num_bonds x hidden
        #a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
        #a_input = torch.cat([f_atoms, a_message], dim=1)  # num_atoms x (atom_fdim + hidden)
        #atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
        #atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden
        ############ origin ############

        # Readout
        atomic_vecs_d0 = []
        atomic_vecs_d1 = []
        atomic_vecs_d2 = []
        atomic_vecs_final = []
        mol_vecs = []
        for i, (a_start, a_size) in enumerate(a_scope):
            if a_size == 0:
                mol_vecs.append(self.cached_zero_vector)
            else:
                cur_hiddens_d0 = save_depth[0].narrow(0, a_start, a_size)
                cur_hiddens_d1 = save_depth[1].narrow(0, a_start, a_size)
                cur_hiddens_d2 = save_depth[2].narrow(0, a_start, a_size)
                cur_hiddens_final = save_depth[-1].narrow(0, a_start, a_size)
                cur_hiddens_mol = save_depth[-1].narrow(0, a_start, a_size)

                atomic_vec_d0 = cur_hiddens_d0
                atomic_vec_d0 = self.padding(atomic_vec_d0)  # padding
                atomic_vec_d1 = cur_hiddens_d1
                atomic_vec_d1 = self.padding(atomic_vec_d1)  # padding
                atomic_vec_d2 = cur_hiddens_d2
                atomic_vec_d2 = self.padding(atomic_vec_d2)  # padding
                atomic_vec_final = cur_hiddens_final
                atomic_vec_final = self.padding(atomic_vec_final)  # padding

                mol_vec = cur_hiddens_mol  # (num_atoms, hidden_size)
                mol_vec = mol_vec.sum(dim=0) / a_size

                atomic_vecs_d0.append(atomic_vec_d0)
                atomic_vecs_d1.append(atomic_vec_d1)
                atomic_vecs_d2.append(atomic_vec_d2)
                atomic_vecs_final.append(atomic_vec_final)
                mol_vecs.append(mol_vec)

        atomic_vecs_d0 = torch.stack(atomic_vecs_d0, dim=0)
        atomic_vecs_d1 = torch.stack(atomic_vecs_d1, dim=0)
        atomic_vecs_d2 = torch.stack(atomic_vecs_d2, dim=0)
        atomic_vecs_final = torch.stack(atomic_vecs_final, dim=0)
        mol_vecs = torch.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)

        if self.args.cuda:
            atomic_vecs_d0, atomic_vecs_d1, atomic_vecs_d2, atomic_vecs_final, mol_vecs = atomic_vecs_d0.cuda(
            ), atomic_vecs_d1.cuda(), atomic_vecs_d2.cuda(
            ), atomic_vecs_final.cuda(), mol_vecs.cuda()

        #overall_vecs=torch.cat((mol_vecs,mol_vec_molar),dim=0)
        if self.use_input_features:
            features_batch = features_batch.to(mol_vecs)
            if len(features_batch.shape) == 1:
                features_batch = features_batch.view(
                    [1, features_batch.shape[0]])
            mol_vecs = torch.cat(
                [mol_vecs, features_batch],
                dim=1)  # (num_molecules, num_atoms,  hidden_size)

        return atomic_vecs_d0, atomic_vecs_d1, atomic_vecs_d2, atomic_vecs_final, mol_vecs
Example #11
    def forward(self,
                mol_graph: BatchMolGraph,
                features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
        """
        Encodes a batch of molecular graphs.

        :param mol_graph: A BatchMolGraph representing a batch of molecular graphs.
        :param features_batch: A list of ndarrays containing additional features.
        :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
        """
        if self.use_input_features:
            features_batch = torch.from_numpy(np.stack(features_batch)).float()

            if self.args.cuda:
                features_batch = features_batch.cuda()

            if self.features_only:
                return features_batch

        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components(
        )

        a2a = mol_graph.get_a2a()

        if self.args.cuda or next(self.parameters()).is_cuda:
            f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda(
            ), a2b.cuda(), b2a.cuda(), b2revb.cuda()

            a2a = a2a.cuda()

        # Input
        input = self.W_i(f_atoms)  # num_atoms x hidden_size
        message = self.act_func(input)  # num_bonds x hidden_size

        # Message passing
        for depth in range(self.depth):
            nei_a_message = index_select_ND(
                message, a2a)  # num_atoms x max_num_bonds x hidden
            nei_f_bonds = index_select_ND(
                f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
            nei_message = torch.cat(
                (nei_a_message, nei_f_bonds),
                dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim

            nei_message = self.W_h(
                nei_message)  # num_atoms x max_num_bonds x hidden
            nei_message = self.act_func(nei_message)
            nei_message = self.dropout_layer(nei_message)

            message = nei_message.sum(dim=1)  # num_atoms x hidden

        # Output step:
        a_input = torch.cat([f_atoms, message],
                            dim=1)  # num_atoms x (atom_fdim + hidden)
        atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
        atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden

        # Readout
        mol_vecs = []
        for i, (a_start, a_size) in enumerate(a_scope):
            if a_size == 0:
                mol_vecs.append(self.cached_zero_vector)
            else:
                cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)
                mol_vec = cur_hiddens  # (num_atoms, hidden_size)

                mol_vec = mol_vec.sum(dim=0) / a_size
                mol_vecs.append(mol_vec)

        mol_vecs = torch.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)

        if self.use_input_features:
            features_batch = features_batch.to(mol_vecs)
            if len(features_batch.shape) == 1:
                features_batch = features_batch.view(
                    [1, features_batch.shape[0]])
            mol_vecs = torch.cat([mol_vecs, features_batch],
                                 dim=1)  # (num_molecules, hidden_size)

        return mol_vecs  # num_molecules x hidden
Example #12
    def forward(
        self,
        mol_graph: BatchMolGraph,
        features_batch: List[np.ndarray] = None,
        viz_dir: str = None
    ) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
        """
        Encodes a batch of molecular graphs.

        :param mol_graph: A BatchMolGraph representing a batch of molecular graphs.
        :param features_batch: A list of ndarrays containing additional features.
        :param viz_dir: Directory in which to save visualized attention weights.
        :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
        """
        if self.use_input_features:
            features_batch = torch.from_numpy(np.stack(features_batch)).float()

            if self.args.cuda:
                features_batch = features_batch.cuda()

            if self.features_only:
                return features_batch

        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components(
        )

        if self.atom_messages:
            a2a = mol_graph.get_a2a()

        if self.args.cuda or next(self.parameters()).is_cuda:
            f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda(
            ), a2b.cuda(), b2a.cuda(), b2revb.cuda()

            if self.atom_messages:
                a2a = a2a.cuda()

        if self.learn_virtual_edges:
            atom1_features, atom2_features = f_atoms[b2a], f_atoms[
                b2a[b2revb]]  # each num_bonds x atom_fdim
            ve_score = torch.sum(
                self.lve(atom1_features) * atom2_features, dim=1) + torch.sum(
                    self.lve(atom2_features) * atom1_features, dim=1)
            is_ve_indicator_index = self.atom_fdim  # in current featurization, the first bond feature is 1 or 0 for virtual or not virtual
            num_virtual = f_bonds[:, is_ve_indicator_index].sum()
            straight_through_mask = torch.ones(
                f_bonds.size(0)).cuda() + f_bonds[:, is_ve_indicator_index] * (
                    ve_score -
                    ve_score.detach()) / num_virtual  # normalize for grad norm
            straight_through_mask = straight_through_mask.unsqueeze(1).repeat(
                (1, self.hidden_size))  # num_bonds x hidden_size

        # Input
        if self.atom_messages:
            input = self.W_i(f_atoms)  # num_atoms x hidden_size
        else:
            input = self.W_i(f_bonds)  # num_bonds x hidden_size
        message = self.act_func(input)  # num_bonds x hidden_size

        if self.message_attention:
            b2b = mol_graph.get_b2b(
            )  # Warning: this is O(n_atoms^3) when using virtual edges

            if next(self.parameters()).is_cuda:
                b2b = b2b.cuda()

            message_attention_mask = (b2b !=
                                      0).float()  # num_bonds x max_num_bonds

        if self.global_attention:
            global_attention_mask = torch.zeros(
                mol_graph.n_bonds, mol_graph.n_bonds)  # num_bonds x num_bonds

            for start, length in b_scope:
                for i in range(start, start + length):
                    global_attention_mask[i, start:start + length] = 1

            if next(self.parameters()).is_cuda:
                global_attention_mask = global_attention_mask.cuda()

        # Message passing
        for depth in range(self.depth - 1):
            if self.undirected:
                message = (message + message[b2revb]) / 2

            if self.learn_virtual_edges:
                message = message * straight_through_mask

            if self.message_attention:
                # TODO: Parallelize attention heads
                nei_message = index_select_ND(message, b2b)
                message = message.unsqueeze(1).repeat(
                    (1, nei_message.size(1), 1))  # num_bonds x maxnb x hidden
                attention_scores = [
                    (self.W_ma[i](nei_message) * message).sum(dim=2)
                    for i in range(self.num_heads)
                ]  # num_bonds x maxnb
                attention_scores = [
                    attention_scores[i] * message_attention_mask +
                    (1 - message_attention_mask) * (-1e+20)
                    for i in range(self.num_heads)
                ]  # num_bonds x maxnb
                attention_weights = [
                    F.softmax(attention_scores[i], dim=1)
                    for i in range(self.num_heads)
                ]  # num_bonds x maxnb
                message_components = [
                    nei_message * attention_weights[i].unsqueeze(2).repeat(
                        (1, 1, self.hidden_size))
                    for i in range(self.num_heads)
                ]  # num_bonds x maxnb x hidden
                message_components = [
                    component.sum(dim=1) for component in message_components
                ]  # num_bonds x hidden
                message = torch.cat(message_components,
                                    dim=1)  # num_bonds x num_heads * hidden
            elif self.atom_messages:
                nei_a_message = index_select_ND(
                    message, a2a)  # num_atoms x max_num_bonds x hidden
                nei_f_bonds = index_select_ND(
                    f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
                nei_message = torch.cat(
                    (nei_a_message, nei_f_bonds),
                    dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim
                message = nei_message.sum(
                    dim=1)  # num_atoms x hidden + bond_fdim
            else:
                # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
                # message      a_message = sum(nei_a_message)      rev_message
                nei_a_message = index_select_ND(
                    message, a2b)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                rev_message = message[b2revb]  # num_bonds x hidden
                message = a_message[b2a] - rev_message  # num_bonds x hidden

            for lpm in range(self.layers_per_message - 1):
                message = self.W_h[lpm][depth](message)  # num_bonds x hidden
                message = self.act_func(message)
            message = self.W_h[self.layers_per_message - 1][depth](message)

            if self.normalize_messages:
                message = message / message.norm(dim=1, keepdim=True)

            if self.master_node:
                # master_state = self.W_master_in(self.act_func(nei_message.sum(dim=0))) #try something like this to preserve invariance for master node
                # master_state = self.GRU_master(nei_message.unsqueeze(1))
                # master_state = master_state[-1].squeeze(0) #this actually doesn't preserve order invariance anymore
                mol_vecs = [self.cached_zero_vector]
                for start, size in b_scope:
                    if size == 0:
                        continue
                    mol_vec = message.narrow(0, start, size)
                    mol_vec = mol_vec.sum(dim=0) / size
                    mol_vecs += [mol_vec for _ in range(size)]
                master_state = self.act_func(
                    self.W_master_in(torch.stack(
                        mol_vecs, dim=0)))  # num_bonds x hidden_size
                message = self.act_func(
                    input + message +
                    self.W_master_out(master_state))  # num_bonds x hidden_size
            else:
                message = self.act_func(input +
                                        message)  # num_bonds x hidden_size

            if self.global_attention:
                attention_scores = torch.matmul(
                    self.W_ga1(message), message.t())  # num_bonds x num_bonds
                attention_scores = attention_scores * global_attention_mask + (
                    1 - global_attention_mask) * (-1e+20
                                                  )  # num_bonds x num_bonds
                attention_weights = F.softmax(attention_scores,
                                              dim=1)  # num_bonds x num_bonds
                attention_hiddens = torch.matmul(
                    attention_weights, message)  # num_bonds x hidden_size
                attention_hiddens = self.act_func(
                    self.W_ga2(attention_hiddens))  # num_bonds x hidden_size
                attention_hiddens = self.dropout_layer(
                    attention_hiddens)  # num_bonds x hidden_size
                message = message + attention_hiddens  # num_bonds x hidden_size

                if viz_dir is not None:
                    visualize_bond_attention(viz_dir, mol_graph,
                                             attention_weights, depth)

            if self.use_layer_norm:
                message = self.layer_norm(message)

            message = self.dropout_layer(message)  # num_bonds x hidden

        if self.master_node and self.use_master_as_output:
            assert self.hidden_size == self.master_dim
            mol_vecs = []
            for start, size in b_scope:
                if size == 0:
                    mol_vecs.append(self.cached_zero_vector)
                else:
                    mol_vecs.append(master_state[start])
            return torch.stack(mol_vecs, dim=0)

        # Get atom hidden states from message hidden states
        if self.learn_virtual_edges:
            message = message * straight_through_mask

        a2x = a2a if self.atom_messages else a2b
        nei_a_message = index_select_ND(
            message, a2x)  # num_atoms x max_num_bonds x hidden
        a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
        a_input = torch.cat([f_atoms, a_message],
                            dim=1)  # num_atoms x (atom_fdim + hidden)
        atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
        atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden

        if self.deepset:
            atom_hiddens = self.W_s2s_a(atom_hiddens)
            atom_hiddens = self.act_func(atom_hiddens)
            atom_hiddens = self.W_s2s_b(atom_hiddens)

        if self.bert_pretraining:
            atom_preds = self.W_v(atom_hiddens)[
                1:]  # num_atoms x vocab/output size (leave out atom padding)

        # Readout
        if self.set2set:
            # Set up sizes
            batch_size = len(a_scope)
            lengths = [length for _, length in a_scope]
            max_num_atoms = max(lengths)

            # Set up memory from atom features
            memory = torch.zeros(
                batch_size, max_num_atoms,
                self.hidden_size)  # (batch_size, max_num_atoms, hidden_size)
            for i, (start, size) in enumerate(a_scope):
                memory[i, :size] = atom_hiddens.narrow(0, start, size)
            memory_transposed = memory.transpose(
                2, 1)  # (batch_size, hidden_size, max_num_atoms)

            # Create mask (1s for atoms, 0s for not atoms)
            mask = create_mask(lengths, cuda=next(
                self.parameters()).is_cuda)  # (max_num_atoms, batch_size)
            mask = mask.t().unsqueeze(2)  # (batch_size, max_num_atoms, 1)

            # Set up query
            query = torch.ones(
                1, batch_size,
                self.hidden_size)  # (1, batch_size, hidden_size)

            # Move to cuda
            if next(self.parameters()).is_cuda:
                memory, memory_transposed, query = memory.cuda(
                ), memory_transposed.cuda(), query.cuda()

            # Run RNN
            for _ in range(self.set2set_iters):
                # Compute attention weights over atoms in each molecule
                query = query.squeeze(0).unsqueeze(
                    2)  # (batch_size,  hidden_size, 1)
                dot = torch.bmm(memory,
                                query)  # (batch_size, max_num_atoms, 1)
                dot = dot * mask + (1 - mask) * (
                    -1e+20)  # (batch_size, max_num_atoms, 1)
                attention = F.softmax(dot,
                                      dim=1)  # (batch_size, max_num_atoms, 1)

                # Construct next input as attention over memory
                attended = torch.bmm(memory_transposed,
                                     attention)  # (batch_size, hidden_size, 1)
                attended = attended.view(
                    1, batch_size,
                    self.hidden_size)  # (1, batch_size, hidden_size)

                # Run RNN for one step
                query, _ = self.set2set_rnn(
                    attended)  # (1, batch_size, hidden_size)

            # Final RNN output is the molecule encodings
            mol_vecs = query.squeeze(0)  # (batch_size, hidden_size)
        else:
            mol_vecs = []
            # TODO: Maybe do this in parallel with masking rather than looping
            for i, (a_start, a_size) in enumerate(a_scope):
                if a_size == 0:
                    mol_vecs.append(self.cached_zero_vector)
                else:
                    cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)

                    if self.attention:
                        att_w = torch.matmul(self.W_a(cur_hiddens),
                                             cur_hiddens.t())
                        att_w = F.softmax(att_w, dim=1)
                        att_hiddens = torch.matmul(att_w, cur_hiddens)
                        att_hiddens = self.act_func(self.W_b(att_hiddens))
                        att_hiddens = self.dropout_layer(att_hiddens)
                        mol_vec = (cur_hiddens + att_hiddens)

                        if viz_dir is not None:
                            visualize_atom_attention(viz_dir,
                                                     mol_graph.smiles_batch[i],
                                                     a_size, att_w)
                    else:
                        mol_vec = cur_hiddens  # (num_atoms, hidden_size)

                    mol_vec = mol_vec.sum(dim=0) / a_size
                    mol_vecs.append(mol_vec)

            mol_vecs = torch.stack(mol_vecs,
                                   dim=0)  # (num_molecules, hidden_size)

        if self.use_input_features:
            features_batch = features_batch.to(mol_vecs)
            if len(features_batch.shape) == 1:
                features_batch = features_batch.view(
                    [1, features_batch.shape[0]])
            mol_vecs = torch.cat([mol_vecs, features_batch],
                                 dim=1)  # (num_molecules, hidden_size)

        if self.bert_pretraining:
            features_preds = self.W_f(mol_vecs) if hasattr(self,
                                                           'W_f') else None
            return {'features': features_preds, 'vocab': atom_preds}

        return mol_vecs  # num_molecules x hidden
Example #13
    def forward(
        self,
        mol_graph: BatchMolGraph,
        features_batch: List[np.ndarray] = None,
        viz_dir: str = None
    ) -> Union[torch.FloatTensor, Dict[str, torch.FloatTensor]]:
        """
        Encodes a batch of molecular graphs.

        :param mol_graph: A BatchMolGraph representing a batch of molecular graphs.
        :param features_batch: A list of ndarrays containing additional features.
        :param viz_dir: Directory in which to save visualized attention weights.
        :return: A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
        """
        if self.use_input_features:
            features_batch = torch.from_numpy(np.stack(features_batch)).float()

            if self.args.cuda:
                features_batch = features_batch.cuda()

            if self.features_only:
                return features_batch

        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, uid, targets = mol_graph.get_components(
        )

        if self.atom_messages:
            a2a = mol_graph.get_a2a()

        if self.args.cuda or next(self.parameters()).is_cuda:
            f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda(
            ), a2b.cuda(), b2a.cuda(), b2revb.cuda()

            if self.atom_messages:
                a2a = a2a.cuda()

        if self.learn_virtual_edges:
            atom1_features, atom2_features = f_atoms[b2a], f_atoms[
                b2a[b2revb]]  # each num_bonds x atom_fdim
            ve_score = torch.sum(
                self.lve(atom1_features) * atom2_features, dim=1) + torch.sum(
                    self.lve(atom2_features) * atom1_features, dim=1)
            is_ve_indicator_index = self.atom_fdim  # in current featurization, the first bond feature is 1 or 0 for virtual or not virtual
            num_virtual = f_bonds[:, is_ve_indicator_index].sum()
            straight_through_mask = torch.ones(
                f_bonds.size(0)).cuda() + f_bonds[:, is_ve_indicator_index] * (
                    ve_score -
                    ve_score.detach()) / num_virtual  # normalize for grad norm
            straight_through_mask = straight_through_mask.unsqueeze(1).repeat(
                (1, self.hidden_size))  # num_bonds x hidden_size

        # Input
        if self.atom_messages:
            input = self.W_i(f_atoms)  # num_atoms x hidden_size
        else:
            input = self.W_i(f_bonds)  # num_bonds x hidden_size
        message = self.act_func(input)  # num_bonds x hidden_size

        if self.message_attention:
            b2b = mol_graph.get_b2b(
            )  # Warning: this is O(n_atoms^3) when using virtual edges

            if next(self.parameters()).is_cuda:
                b2b = b2b.cuda()

            message_attention_mask = (b2b !=
                                      0).float()  # num_bonds x max_num_bonds

        if self.global_attention:
            global_attention_mask = torch.zeros(
                mol_graph.n_bonds, mol_graph.n_bonds)  # num_bonds x num_bonds

            for start, length in b_scope:
                for i in range(start, start + length):
                    global_attention_mask[i, start:start + length] = 1

            if next(self.parameters()).is_cuda:
                global_attention_mask = global_attention_mask.cuda()

        # Message passing
        for depth in range(self.depth - 1):
            if self.undirected:
                message = (message + message[b2revb]) / 2

            if self.learn_virtual_edges:
                message = message * straight_through_mask

            if self.message_attention:
                # TODO: Parallelize attention heads
                nei_message = index_select_ND(message, b2b)
                message = message.unsqueeze(1).repeat(
                    (1, nei_message.size(1), 1))  # num_bonds x maxnb x hidden
                attention_scores = [
                    (self.W_ma[i](nei_message) * message).sum(dim=2)
                    for i in range(self.num_heads)
                ]  # num_bonds x maxnb
                attention_scores = [
                    attention_scores[i] * message_attention_mask +
                    (1 - message_attention_mask) * (-1e+20)
                    for i in range(self.num_heads)
                ]  # num_bonds x maxnb
                attention_weights = [
                    F.softmax(attention_scores[i], dim=1)
                    for i in range(self.num_heads)
                ]  # num_bonds x maxnb
                message_components = [
                    nei_message * attention_weights[i].unsqueeze(2).repeat(
                        (1, 1, self.hidden_size))
                    for i in range(self.num_heads)
                ]  # num_bonds x maxnb x hidden
                message_components = [
                    component.sum(dim=1) for component in message_components
                ]  # num_bonds x hidden
                message = torch.cat(message_components,
                                    dim=1)  # num_bonds x num_heads * hidden
            elif self.atom_messages:
                nei_a_message = index_select_ND(
                    message, a2a)  # num_atoms x max_num_bonds x hidden
                nei_f_bonds = index_select_ND(
                    f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
                nei_message = torch.cat(
                    (nei_a_message, nei_f_bonds),
                    dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim
                message = nei_message.sum(
                    dim=1)  # num_atoms x hidden + bond_fdim
            else:
                # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
                # message      a_message = sum(nei_a_message)      rev_message
                nei_a_message = index_select_ND(
                    message, a2b)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                rev_message = message[b2revb]  # num_bonds x hidden
                message = a_message[b2a] - rev_message  # num_bonds x hidden

            for lpm in range(self.layers_per_message - 1):
                message = self.W_h[lpm][depth](message)  # num_bonds x hidden
                message = self.act_func(message)
            message = self.W_h[self.layers_per_message - 1][depth](message)

            if self.normalize_messages:
                message = message / message.norm(dim=1, keepdim=True)

            if self.master_node:
                # master_state = self.W_master_in(self.act_func(nei_message.sum(dim=0))) #try something like this to preserve invariance for master node
                # master_state = self.GRU_master(nei_message.unsqueeze(1))
                # master_state = master_state[-1].squeeze(0) #this actually doesn't preserve order invariance anymore
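                # Active master-node variant: average the bond messages of each molecule, replicate
                # that summary once per bond (index 0 is the padding row, cached_zero_vector), and
                # feed it back through W_master_in / W_master_out so every bond update also sees a
                # molecule-level state in addition to the usual skip connection.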
                mol_vecs = [self.cached_zero_vector]
                for start, size in b_scope:
                    if size == 0:
                        continue
                    mol_vec = message.narrow(0, start, size)
                    mol_vec = mol_vec.sum(dim=0) / size
                    mol_vecs += [mol_vec for _ in range(size)]
                master_state = self.act_func(
                    self.W_master_in(torch.stack(
                        mol_vecs, dim=0)))  # num_bonds x hidden_size
                message = self.act_func(
                    input + message +
                    self.W_master_out(master_state))  # num_bonds x hidden_size
            else:
                message = self.act_func(input +
                                        message)  # num_bonds x hidden_size

            if self.global_attention:
                attention_scores = torch.matmul(
                    self.W_ga1(message), message.t())  # num_bonds x num_bonds
                attention_scores = attention_scores * global_attention_mask + (
                    1 - global_attention_mask) * (-1e+20)  # num_bonds x num_bonds
                attention_weights = F.softmax(attention_scores,
                                              dim=1)  # num_bonds x num_bonds
                attention_hiddens = torch.matmul(
                    attention_weights, message)  # num_bonds x hidden_size
                attention_hiddens = self.act_func(
                    self.W_ga2(attention_hiddens))  # num_bonds x hidden_size
                attention_hiddens = self.dropout_layer(
                    attention_hiddens)  # num_bonds x hidden_size
                message = message + attention_hiddens  # num_bonds x hidden_size

                if viz_dir is not None:
                    visualize_bond_attention(viz_dir, mol_graph,
                                             attention_weights, depth)

            # Bond attention during message passing
            if self.bond_attention:
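                # Multi-head additive attention over bonds: per head, W_ap -> activation -> V_ap
                # produces one scalar per bond, softmax-normalized over all bonds in the batch
                # (dim=0); the reweighted messages are then averaged across heads.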
                alphas = [
                    self.act_func(self.W_ap[j](message))
                    for j in range(self.attention_pooling_heads)
                ]  # num_bonds x hidden_size
                alphas = [
                    self.V_ap[j](alphas[j])
                    for j in range(self.attention_pooling_heads)
                ]  # num_bonds x 1
                alphas = [
                    F.softmax(alphas[j], dim=0)
                    for j in range(self.attention_pooling_heads)
                ]  # num_bonds x 1
                alphas = [
                    alphas[j].squeeze(1)
                    for j in range(self.attention_pooling_heads)
                ]  # num_bonds
                att_hiddens = [
                    torch.mul(alphas[j], message.t()).t()
                    for j in range(self.attention_pooling_heads)
                ]  # num_bonds x hidden_size
                att_hiddens = sum(att_hiddens) / float(
                    self.attention_pooling_heads)  # num_bonds x hidden_size

                message = att_hiddens  # att_hiddens is the new message

                # Create visualizations (time-consuming, best only done on test set)
                if self.attention_viz and depth == self.depth - 2:  # Visualize at end of primary message passing phase
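                    # Visualization bookkeeping: for each molecule, reconstruct which atom pair every
                    # directed bond connects (from b2a and a2b, using local 1-based indices) and dump
                    # the per-bond attention weights. Note that viz_dir is reassigned below, which
                    # overrides the viz_dir variable used by the other visualization branches in
                    # this method.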

                    bond_analysis = dict()
                    for dict_key in range(8):
                        bond_analysis[dict_key] = []

                    # Loop through the individual graphs in the batch
                    for i, (a_start, a_size) in enumerate(a_scope):
                        if a_size == 0:  # Skip over empty graphs
                            continue
                        else:
                            for j in range(self.attention_pooling_heads):
                                atoms = f_atoms[
                                    a_start:a_start +
                                    a_size, :].cpu().numpy()  # Atom features
                                bonds1 = a2b[a_start:a_start +
                                             a_size, :]  # a2b
                                bonds2 = b2a[b_scope[i][0]:b_scope[i][0] +
                                             b_scope[i][1]]  # b2a
                                bonds_dict = {
                                }  # Dictionary to keep track of atoms that bonds are connecting
                                for k in range(len(bonds2)):  # Collect info from b2a
                                    bonds_dict[k + 1] = bonds2[k].item() - a_start + 1
                                for k in range(bonds1.shape[0]):  # Collect info from a2b
                                    for m in range(bonds1.shape[1]):
                                        bond_num = bonds1[
                                            k, m].item() - b_scope[i][0] + 1
                                        if bond_num > 0:
                                            bonds_dict[bond_num] = (
                                                bonds_dict[bond_num], k + 1)

                                # Save weights for this graph
                                weights = alphas[j].cpu().data.numpy()[
                                    b_scope[i][0]:b_scope[i][0] + b_scope[i][1]]
                                id_number = uid[i].item()  # Graph uid
                                label = targets[i].item()  # Graph label
                                viz_dir = self.args.save_dir + '/' + 'bond_attention_visualizations'  # Folder
                                os.makedirs(viz_dir, exist_ok=True)  # Create the folder only if it does not already exist

                                label, num_subgraphs, num_type_2_connections, num_type_1_isolated, num_type_2_isolated = visualize_bond_attention_pooling(
                                    atoms, bonds_dict, weights, id_number,
                                    label, viz_dir)

                                # Write analysis results
                                if not os.path.exists(
                                        self.args.save_dir + '/' +
                                        'attention_analysis.txt'):
                                    f = open(
                                        self.args.save_dir + '/' +
                                        'attention_analysis.txt', 'w')
                                    f.write(
                                        '# Category    Subgraphs    Type 2 Bonds    Type 1 Isolated    '
                                        'Type 2 Isolated' + '\n')
                                else:
                                    f = open(
                                        self.args.save_dir + '/' +
                                        'attention_analysis.txt', 'a')
                                f.write(
                                    str(label) + '    ' + str(num_subgraphs) +
                                    '    ' + str(num_type_2_connections) +
                                    '    ' + str(num_type_1_isolated) +
                                    '    ' + str(num_type_2_isolated) + '\n')
                                f.close()  # close the handle opened above to avoid leaking file descriptors

                                bond_analysis[label].append(num_subgraphs)
                                bond_analysis[label +
                                              2].append(num_type_2_connections)
                                bond_analysis[label +
                                              4].append(num_type_1_isolated)
                                bond_analysis[label +
                                              6].append(num_type_2_isolated)

                    # # Write analysis results
                    # if not os.path.exists(self.args.save_dir + '/' + 'attention_analysis.txt'):
                    #     f = open(self.args.save_dir + '/' + 'attention_analysis.txt', 'w')
                    #     f.write('# Category 0 Subgraphs    Category 1 Subgraphs    '
                    #             'Category 0 Type 2 Bonds    Category 1 Type 2 Bonds    '
                    #             'Category 0 Type 1 Isolated    Category 1 Type 1 Isolated    '
                    #             'Category 0 Type 2 Isolated    Category 1 Type 2 Isolated' '\n')
                    # else:
                    #     f = open(self.args.save_dir + '/' + 'attention_analysis.txt', 'a')
                    # f.write(str(np.mean(np.array(bond_analysis[0]))) + '    ' +
                    #         str(np.mean(np.array(bond_analysis[1]))) + '    ' +
                    #         str(np.mean(np.array(bond_analysis[2]))) + '    ' +
                    #         str(np.mean(np.array(bond_analysis[3]))) + '    ' +
                    #         str(np.mean(np.array(bond_analysis[4]))) + '    ' +
                    #         str(np.mean(np.array(bond_analysis[5]))) + '    ' +
                    #         str(np.mean(np.array(bond_analysis[6]))) + '    ' +
                    #         str(np.mean(np.array(bond_analysis[7]))) + '    ' + '\n')
                    # f.close()

            if self.use_layer_norm:
                message = self.layer_norm(message)

            message = self.dropout_layer(message)  # num_bonds x hidden

        if self.master_node and self.use_master_as_output:
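        # Skip the atom-level readout entirely: every bond of a molecule carries the same master
        # summary, so master_state[start] (that molecule's first bond row) is used as its vector.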
            assert self.hidden_size == self.master_dim
            mol_vecs = []
            for start, size in b_scope:
                if size == 0:
                    mol_vecs.append(self.cached_zero_vector)
                else:
                    mol_vecs.append(master_state[start])
            return torch.stack(mol_vecs, dim=0)

        # Get atom hidden states from message hidden states
        if self.learn_virtual_edges:
            message = message * straight_through_mask

        if self.bond_attention_pooling:
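            # One extra directed-bond update followed by W_o, producing bond-level hidden states;
            # the readout further below then attends over bonds (b_scope) instead of atoms.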
            nei_a_message = index_select_ND(
                message, a2b)  # num_atoms x max_num_bonds x hidden
            a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
            rev_message = message[b2revb]  # num_bonds x hidden
            message = a_message[b2a] - rev_message  # num_bonds x hidden
            message = self.act_func(self.W_o(message))  # num_bonds x hidden
            bond_hiddens = self.dropout_layer(message)  # num_bonds x hidden

        else:
            a2x = a2a if self.atom_messages else a2b
            nei_a_message = index_select_ND(
                message, a2x)  # num_atoms x max_num_bonds x hidden
            a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
            a_input = torch.cat([f_atoms, a_message],
                                dim=1)  # num_atoms x (atom_fdim + hidden)
            atom_hiddens = self.act_func(
                self.W_o(a_input))  # num_atoms x hidden
            atom_hiddens = self.dropout_layer(
                atom_hiddens)  # num_atoms x hidden

        if self.deepset:
            atom_hiddens = self.W_s2s_a(atom_hiddens)
            atom_hiddens = self.act_func(atom_hiddens)
            atom_hiddens = self.W_s2s_b(atom_hiddens)

        if self.bert_pretraining:
            atom_preds = self.W_v(atom_hiddens)[
                1:]  # num_atoms x vocab/output size (leave out atom padding)

        # Readout
        if self.set2set:
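            # Set2Set-style readout: pack atom hiddens into a padded memory tensor, then repeatedly
            # (set2set_iters times) attend over that memory with a query that an RNN updates each
            # step; padding positions are pushed to -1e20 before the softmax, and the final query
            # serves as the molecule embedding.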
            # Set up sizes
            batch_size = len(a_scope)
            lengths = [length for _, length in a_scope]
            max_num_atoms = max(lengths)

            # Set up memory from atom features
            memory = torch.zeros(
                batch_size, max_num_atoms,
                self.hidden_size)  # (batch_size, max_num_atoms, hidden_size)
            for i, (start, size) in enumerate(a_scope):
                memory[i, :size] = atom_hiddens.narrow(0, start, size)
            memory_transposed = memory.transpose(
                2, 1)  # (batch_size, hidden_size, max_num_atoms)

            # Create mask (1s for atoms, 0s for not atoms)
            mask = create_mask(lengths, cuda=next(
                self.parameters()).is_cuda)  # (max_num_atoms, batch_size)
            mask = mask.t().unsqueeze(2)  # (batch_size, max_num_atoms, 1)

            # Set up query
            query = torch.ones(
                1, batch_size,
                self.hidden_size)  # (1, batch_size, hidden_size)

            # Move to cuda
            if next(self.parameters()).is_cuda:
                memory, memory_transposed, query = (
                    memory.cuda(), memory_transposed.cuda(), query.cuda())

            # Run RNN
            for _ in range(self.set2set_iters):
                # Compute attention weights over atoms in each molecule
                query = query.squeeze(0).unsqueeze(
                    2)  # (batch_size,  hidden_size, 1)
                dot = torch.bmm(memory,
                                query)  # (batch_size, max_num_atoms, 1)
                dot = dot * mask + (1 - mask) * (
                    -1e+20)  # (batch_size, max_num_atoms, 1)
                attention = F.softmax(dot,
                                      dim=1)  # (batch_size, max_num_atoms, 1)

                # Construct next input as attention over memory
                attended = torch.bmm(memory_transposed,
                                     attention)  # (batch_size, hidden_size, 1)
                attended = attended.view(
                    1, batch_size,
                    self.hidden_size)  # (1, batch_size, hidden_size)

                # Run RNN for one step
                query, _ = self.set2set_rnn(
                    attended)  # (1, batch_size, hidden_size)

            # Final RNN output is the molecule encodings
            mol_vecs = query.squeeze(0)  # (batch_size, hidden_size)
        else:
            mol_vecs = []

            if self.bond_attention_pooling:
                for i, (b_start, b_size) in enumerate(b_scope):
                    if b_size == 0:
                        mol_vecs.append(self.cached_zero_vector)
                    else:
                        cur_hiddens = bond_hiddens.narrow(0, b_start, b_size)
                        alphas = [
                            self.act_func(self.W_ap[j](cur_hiddens))
                            for j in range(self.attention_pooling_heads)
                        ]
                        alphas = [
                            self.V_ap[j](alphas[j])
                            for j in range(self.attention_pooling_heads)
                        ]
                        alphas = [
                            F.softmax(alphas[j], dim=0)
                            for j in range(self.attention_pooling_heads)
                        ]
                        alphas = [
                            alphas[j].squeeze(1)
                            for j in range(self.attention_pooling_heads)
                        ]
                        att_hiddens = [
                            torch.mul(alphas[j], cur_hiddens.t()).t()
                            for j in range(self.attention_pooling_heads)
                        ]
                        att_hiddens = sum(att_hiddens) / float(
                            self.attention_pooling_heads)

                        mol_vec = att_hiddens

                        mol_vec = mol_vec.sum(dim=0)
                        mol_vecs.append(mol_vec)

                        if self.attention_viz:

                            for j in range(self.attention_pooling_heads):

                                atoms = f_atoms[
                                    a_scope[i][0]:a_scope[i][0] +
                                    a_scope[i][1], :].cpu().numpy()
                                bonds1 = a2b[a_scope[i][0]:a_scope[i][0] +
                                             a_scope[i][1], :]
                                bonds2 = b2a[b_scope[i][0]:b_scope[i][0] +
                                             b_scope[i][1]]
                                bonds_dict = {}
                                for k in range(len(bonds2)):
                                    bonds_dict[k + 1] = bonds2[k].item() - a_scope[i][0] + 1
                                for k in range(bonds1.shape[0]):
                                    for m in range(bonds1.shape[1]):
                                        bond_num = bonds1[
                                            k, m].item() - b_scope[i][0] + 1
                                        if bond_num > 0:
                                            bonds_dict[bond_num] = (
                                                bonds_dict[bond_num], k + 1)

                                weights = alphas[j].cpu().data.numpy()
                                id_number = uid[i].item()
                                label = targets[i].item()
                                viz_dir = self.args.save_dir + '/' + 'bond_attention_pooling_visualizations'
                                os.makedirs(viz_dir, exist_ok=True)

                                visualize_bond_attention_pooling(
                                    atoms, bonds_dict, weights, id_number,
                                    label, viz_dir)

            else:
                # TODO: Maybe do this in parallel with masking rather than looping
                for i, (a_start, a_size) in enumerate(a_scope):
                    if a_size == 0:
                        mol_vecs.append(self.cached_zero_vector)
                    else:
                        cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)

                        if self.attention:
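                            # Single-head dot-product self-attention over the atoms of this
                            # molecule, added residually to the raw atom hiddens before pooling.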
                            att_w = torch.matmul(self.W_a(cur_hiddens),
                                                 cur_hiddens.t())
                            att_w = F.softmax(att_w, dim=1)
                            att_hiddens = torch.matmul(att_w, cur_hiddens)
                            att_hiddens = self.act_func(self.W_b(att_hiddens))
                            att_hiddens = self.dropout_layer(att_hiddens)
                            mol_vec = (cur_hiddens + att_hiddens)

                            if viz_dir is not None:
                                visualize_atom_attention(
                                    viz_dir, mol_graph.smiles_batch[i], a_size,
                                    att_w)

                        elif self.attention_pooling:
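                            # Same multi-head additive attention as in the message-passing phase,
                            # but the softmax here runs only over this molecule's atoms; the
                            # division by a_size further below additionally rescales the
                            # attention-weighted sum (the TODO notes it should be removed here).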
                            alphas = [
                                self.act_func(self.W_ap[j](cur_hiddens))
                                for j in range(self.attention_pooling_heads)
                            ]
                            alphas = [
                                self.V_ap[j](alphas[j])
                                for j in range(self.attention_pooling_heads)
                            ]
                            alphas = [
                                F.softmax(alphas[j], dim=0)
                                for j in range(self.attention_pooling_heads)
                            ]
                            alphas = [
                                alphas[j].squeeze(1)
                                for j in range(self.attention_pooling_heads)
                            ]
                            att_hiddens = [
                                torch.mul(alphas[j], cur_hiddens.t()).t()
                                for j in range(self.attention_pooling_heads)
                            ]
                            att_hiddens = sum(att_hiddens) / float(
                                self.attention_pooling_heads)

                            mol_vec = att_hiddens

                            if self.attention_viz:

                                for j in range(self.attention_pooling_heads):

                                    atoms = f_atoms[a_start:a_start +
                                                    a_size, :].cpu().numpy()
                                    weights = alphas[j].cpu().data.numpy()
                                    id_number = uid[i].item()
                                    label = targets[i].item()
                                    viz_dir = self.args.save_dir + '/' + 'attention_pooling_visualizations'
                                    os.makedirs(viz_dir, exist_ok=True)

                                    visualize_attention_pooling(
                                        atoms, weights, id_number, label,
                                        viz_dir)
                                    analyze_attention_pooling(
                                        atoms, weights, id_number, label,
                                        viz_dir)

                        else:
                            mol_vec = cur_hiddens  # (num_atoms, hidden_size)

                        mol_vec = mol_vec.sum(
                            dim=0
                        ) / a_size  #TODO: remove a_size division in case of attention_pooling
                        mol_vecs.append(mol_vec)

            mol_vecs = torch.stack(mol_vecs,
                                   dim=0)  # (num_molecules, hidden_size)

        if self.use_input_features:
            features_batch = features_batch.to(mol_vecs)
            if len(features_batch.shape) == 1:
                features_batch = features_batch.view(
                    [1, features_batch.shape[0]])
            mol_vecs = torch.cat([mol_vecs, features_batch],
                                 dim=1)  # (num_molecules, hidden_size + features_size)

        if self.bert_pretraining:
            features_preds = self.W_f(mol_vecs) if hasattr(self, 'W_f') else None
            return {'features': features_preds, 'vocab': atom_preds}

        return mol_vecs  # num_molecules x hidden
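
Aside: the masked-softmax idiom used above (in the message-attention, global-attention and set2set branches) pushes invalid positions to -1e20 before the softmax so that padding receives effectively zero attention weight. Below is a minimal, self-contained sketch of that idiom; tensor names and sizes are illustrative only and are not part of the encoders in these examples.

import torch
import torch.nn.functional as F

batch_size, max_num_atoms, hidden_size = 2, 4, 8
memory = torch.randn(batch_size, max_num_atoms, hidden_size)  # padded per-atom hidden states
lengths = torch.tensor([3, 4])  # number of real atoms in each molecule

# 1 for real atoms, 0 for padding slots
mask = (torch.arange(max_num_atoms).unsqueeze(0) < lengths.unsqueeze(1)).float()
mask = mask.unsqueeze(2)  # (batch_size, max_num_atoms, 1)

query = torch.ones(batch_size, hidden_size, 1)  # one query vector per molecule
scores = torch.bmm(memory, query)  # (batch_size, max_num_atoms, 1)
scores = scores * mask + (1 - mask) * (-1e+20)  # padded atoms get a huge negative score
weights = F.softmax(scores, dim=1)  # attention distributed only over real atoms
mol_vecs = torch.bmm(memory.transpose(2, 1), weights).squeeze(2)  # (batch_size, hidden_size)
print(mol_vecs.shape)  # torch.Size([2, 8])
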
Exemple #14
0
    def forward(self, mol_graph: BatchMolGraph, features_batch: T.Tensor):
        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope = mol_graph.get_components()

        if self.atom_messages:
            a2a = mol_graph.get_a2a()
            input = self.W_i(f_atoms)
        else:
            input = self.W_i(f_bonds)
        message = self.act_func(input)

        # Message passing
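        # In the default (bond-message) branch this is the D-MPNN-style update: sum the messages
        # arriving at each atom, subtract the reverse bond's message to avoid immediate
        # back-tracking, apply W_h, and add the initial input projection as a skip connection,
        # with dropout at every depth.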
        for depth in range(self.depth - 1):
            if self.undirected:
                message = (message + message[b2revb]) / 2

            if self.atom_messages:
                nei_a_message = index_select_ND(
                    message, a2a)  # num_atoms x max_num_bonds x hidden
                nei_f_bonds = index_select_ND(
                    f_bonds, a2b)  # num_atoms x max_num_bonds x bond_fdim
                nei_message = T.cat(
                    (nei_a_message, nei_f_bonds),
                    dim=2)  # num_atoms x max_num_bonds x hidden + bond_fdim
                message = nei_message.sum(
                    dim=1)  # num_atoms x hidden + bond_fdim
            else:
                # m(a1 -> a2) = [sum_{a0 \in nei(a1)} m(a0 -> a1)] - m(a2 -> a1)
                # message      a_message = sum(nei_a_message)      rev_message
                nei_a_message = index_select_ND(
                    message, a2b)  # num_atoms x max_num_bonds x hidden
                a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
                rev_message = message[b2revb]  # num_bonds x hidden
                message = a_message[b2a] - rev_message  # num_bonds x hidden

            message = self.W_h(message)
            message = self.act_func(input + message)  # num_bonds x hidden_size
            message = self.dropout_layer(message)  # num_bonds x hidden

        a2x = a2a if self.atom_messages else a2b
        nei_a_message = index_select_ND(
            message, a2x)  # num_atoms x max_num_bonds x hidden
        a_message = nei_a_message.sum(dim=1)  # num_atoms x hidden
        a_input = T.cat([f_atoms, a_message],
                        dim=1)  # num_atoms x (atom_fdim + hidden)
        atom_hiddens = self.act_func(self.W_o(a_input))  # num_atoms x hidden
        atom_hiddens = self.dropout_layer(atom_hiddens)  # num_atoms x hidden

        # Readout
        mol_vecs = []
        for i, (a_start, a_size) in enumerate(a_scope):
            if a_size == 0:
                mol_vecs.append(self.cached_zero_vector)
            else:
                cur_hiddens = atom_hiddens.narrow(0, a_start, a_size)
                mol_vec = cur_hiddens  # (num_atoms, hidden_size)

                mol_vec = mol_vec.sum(dim=0) / a_size
                #mol_vec = mol_vec.mean(0).values
                mol_vecs.append(mol_vec)

        mol_vecs = T.stack(mol_vecs, dim=0)  # (num_molecules, hidden_size)

        return mol_vecs