def __init__(self,
             dataset: ScaffoldMolDataset,
             sampler: ScaffoldMolSampler,
             k: int = 5,
             p: float = 0.5,
             num_workers: int = 0,
             ms=MoleculeSpec.get_default()):
    """Construct ScaffoldLoader from the given dataset

    Args:
        dataset (Dataset): The Dataset to be loaded
        sampler (BatchSampler): The batch sampler
        k (int): Number of importance samples, default to 5
        p (float): Degree of uncertainty during route sampling,
            should be in (0, 1), default to 0.5
        num_workers (int): The number of workers used for data loading,
            default to 0
        ms (mol_spec.MoleculeSpec)
    """
    self.k = k  # pylint: disable=invalid-name
    self.p = p  # pylint: disable=invalid-name
    self.ms = ms
    super(DataLoader, self).__init__(dataset,
                                     collate_fn=self._collate_fn,
                                     batch_sampler=sampler,
                                     num_workers=num_workers)
def _sample_ordering(mol, scaffold_nodes, k, p, ms=MoleculeSpec.get_default()):
    """Sample decoding routes of a given molecule `mol`

    Args:
        mol (Chem.Mol): The given molecule
        scaffold_nodes (np.ndarray): The nodes marked as scaffold
        k (int): The number of importance samples
        p (float): Degree of uncertainty during route sampling,
            should be in (0, 1)
        ms (mol_spec.MoleculeSpec)

    Returns:
        route_list (np.ndarray):
            route_list[i][j] - the index of the atom reached at step j
            in sample i
        step_ids_list (np.ndarray):
            step_ids_list[i][j] - the step at which atom j is reached
            in sample i
        logp_list (np.ndarray):
            logp_list[i] - the log-likelihood value of route i
    """
    # build graph
    atom_types = []
    for atom in mol.GetAtoms():
        atom_types.append(ms.get_atom_type(atom))
    atom_ranks = []
    for r in Chem.CanonicalRankAtoms(mol):
        atom_ranks.append(r)
    atom_ranks = np.array(atom_ranks)
    bonds = []
    for b in mol.GetBonds():
        idx_1, idx_2 = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bonds.append([idx_1, idx_2])

    # build nx graph
    graph = nx.Graph()
    graph.add_nodes_from(range(len(atom_ranks)))
    graph.add_edges_from(bonds)

    route_list = []
    step_ids_list = []
    logp_list = []
    for _ in range(k):
        step_ids, log_p = _traverse(graph=graph,
                                    atom_ranks=atom_ranks,
                                    scaffold_nodes=scaffold_nodes,
                                    p=p)
        step_ids_list.append(step_ids)
        # step_ids maps atom index -> step; argsort inverts it into a
        # route mapping step -> atom index
        route_list.append(np.argsort(step_ids))
        logp_list.append(log_p)

    # cast to numpy array
    route_list = np.array(route_list, dtype=np.int32)
    step_ids_list = np.array(step_ids_list, dtype=np.int32)
    logp_list = np.array(logp_list, dtype=np.float32)

    return route_list, step_ids_list, logp_list
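# A small numeric illustration (hypothetical values, not from any dataset) of
# the route/step_ids relationship used above: step_ids maps atom index ->
# visiting step, and np.argsort inverts it into a route mapping step -> atom.
def _demo_route_from_step_ids():
    step_ids = np.array([2, 0, 3, 1])  # atom 1 is visited first, then atom 3
    route = np.argsort(step_ids)       # step 0 reaches atom 1, step 1 atom 3
    assert route.tolist() == [1, 3, 0, 2]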
def get_data_loader_full(scaffold_network_loc: str,
                         molecule_smiles_loc: str,
                         batch_size: int,
                         num_iterations: int,
                         num_workers: int,
                         k: int,
                         p: float,
                         ms: MoleculeSpec = MoleculeSpec.get_default()
                         ) -> DataLoader:
    """Helper function for getting the dataloader over the full dataset

    Args:
        scaffold_network_loc (str): The location of the network file
        molecule_smiles_loc (str): The location of the file containing
            molecular SMILES
        batch_size (int): Batch size for training
        num_iterations (int): The number of iterations to train
        num_workers (int): The number of workers used for data loading
        k (int): The number of importance samples
        p (float): Degree of uncertainty during route sampling
        ms (MoleculeSpec, optional)

    Returns:
        DataLoader: The loader over the full dataset
    """
    # Get dataset
    # pylint: disable=invalid-name
    db = ScaffoldMolDataset(scaffold_network_loc, molecule_smiles_loc)

    # The scaffold and molecule sub-batches must be of equal size
    assert batch_size % 2 == 0

    # Get sampler; exclude_ids_loc, training and split_type are not needed
    # when sampling from the full dataset
    sampler = ScaffoldMolSampler(db,
                                 (batch_size // 2, batch_size // 2),
                                 num_iterations,
                                 None, None, None)

    # Get DataLoader
    loader = DataLoader(db, sampler, k, p, num_workers, ms)
    return loader
def pack_encoder(mol_array,
                 ms=MoleculeSpec.get_default()) -> t.Tuple[torch.Tensor, ...]:
    """Pack and expand the information in mol_array in order to feed it into
    graph encoders (the encoder version of the function `pack()`)

    Args:
        mol_array (torch.Tensor): Input molecule array,
            size [batch_size, max_num_steps, 5], type: `torch.long`,
            where 5 = atom_type + begin_ids + end_ids + bond_type + is_scaffold
        ms (mol_spec.MoleculeSpec)

    Returns:
        atom_types (torch.Tensor): Atom type information packed into
            a single vector, type: torch.long
        is_scaffold (torch.Tensor): Whether the corresponding atom is
            contained in the scaffold, type: torch.long
        bond_info (torch.Tensor): Bond type information packed into
            a single matrix, type: torch.long, shape: [-1, 3],
            where 3 = begin_ids + end_ids + bond_type
        last_append_mask (torch.Tensor): Marks the latest appended atom
            of each block, type: torch.long
        block_ids, atom_ids (torch.Tensor): type: torch.long,
            shape: [num_total_atoms, ]
    """
    # get device info
    device = mol_array.device

    # magic numbers: the column layout of mol_array
    I_ATOM_TYPE, I_BEGIN_IDS, I_END_IDS, I_BOND_TYPE, I_IS_SCAFFOLD = range(5)

    # The number of decoding steps required for each input molecule
    # size: [batch_size, ]
    num_steps = mol_array[:, :, I_END_IDS].ge(0).long().sum(-1)

    # Get molecule id and step id for each (unexpanded) atom/node
    # Example:
    # if we have num_steps = [4, 2, 5] and batch_size = 3, then
    # molecule:  |     0     |  1  |      2      |
    # mol_ids  = [0  0  0  0  1  1  2  2  2  2  2]
    # step_ids = [0  1  2  3  0  1  0  1  2  3  4]
    mol_ids, step_ids = rep_and_range(num_steps)
    mol_array_packed = mol_array[mol_ids, step_ids, :]

    # Get the expanded atom type
    atom_types, is_scaffold = (mol_array_packed[:, I_ATOM_TYPE],
                               mol_array_packed[:, I_IS_SCAFFOLD])

    # binary vector with int values: 1 if the step defines a new atom,
    # 0 if it is a connect step
    is_connect = atom_types.ge(0).long()
    is_connect_index = is_connect.nonzero().squeeze()
    num_atoms = torch_scatter.scatter_add(is_connect, dim=0, index=mol_ids)
    atom_types = atom_types[is_connect_index]
    is_scaffold = is_scaffold[is_connect_index]
    block_ids, atom_ids = rep_and_range(num_atoms)

    # Get last append mask
    # Locations of the latest appended atoms
    last_append_loc = torch.cumsum(num_atoms, dim=0) - 1
    # Initialize last_append_mask as zeros
    last_append_mask = torch.full_like(is_scaffold, 0)
    # Mark the latest appended atoms with one
    last_append_mask[last_append_loc] = 1

    # Get (packed) bond information
    # size: [-1, 3], where 3 = begin_ids, end_ids, bond_type
    bond_info = mol_array_packed[:, [I_BEGIN_IDS, I_END_IDS, I_BOND_TYPE]]

    # adjust begin_ids and end_ids for each bond
    num_atoms_cumsum = pad_first(torch.cumsum(num_atoms, dim=0)[:-1])
    I_BOND_BEGIN_IDS = 0
    _filter = bond_info[:, I_BOND_BEGIN_IDS].ge(0).nonzero().squeeze()
    _shift = num_atoms_cumsum[mol_ids]
    _shift = torch.stack([_shift, ] * 2 + [torch.zeros_like(_shift)], dim=1)
    bond_info = (bond_info + _shift)[_filter, :]

    # symmetrize bond_info
    bond_info = torch.cat([bond_info, bond_info[:, [1, 0, 2]]])

    # labels for artificial bonds and atoms
    (I_BOND_REMOTE_2,
     I_BOND_REMOTE_3,
     I_BOND_ATOM_SELF) = range(ms.num_bond_types, ms.num_bond_types + 3)

    # artificial bond type: remote connections (2- and 3-hop neighbors)
    indices = bond_info[:, :2].cpu().numpy()
    size = atom_types.size(0)
    d_indices_2, d_indices_3 = get_remote_connection(indices, size)
    # pylint: disable=not-callable
    d_indices_2, d_indices_3 = (torch.tensor(d_indices_2,
                                             dtype=torch.long,
                                             device=device),
                                torch.tensor(d_indices_3,
                                             dtype=torch.long,
                                             device=device))
    bond_type = torch.full([d_indices_2.size(0), 1],
                           I_BOND_REMOTE_2,
                           dtype=torch.long,
                           device=device)
    bond_info_remote_2 = torch.cat([d_indices_2, bond_type], dim=-1)
    bond_type = torch.full([d_indices_3.size(0), 1],
                           I_BOND_REMOTE_3,
                           dtype=torch.long,
                           device=device)
    bond_info_remote_3 = torch.cat([d_indices_3, bond_type], dim=-1)
    bond_info = torch.cat([bond_info, bond_info_remote_2, bond_info_remote_3],
                          dim=0)

    # artificial bond type: self connection
    begin_ids = end_ids = torch.arange(atom_types.size(0),
                                       dtype=torch.long,
                                       device=atom_types.device)
    bond_type = torch.full_like(end_ids, I_BOND_ATOM_SELF)
    bond_info_self = torch.stack([begin_ids, end_ids, bond_type], dim=-1)
    bond_info = torch.cat([bond_info, bond_info_self], dim=0)

    return (atom_types, is_scaffold, bond_info, last_append_mask,
            block_ids, atom_ids)
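# `rep_and_range` and `pad_first` are used throughout the packing functions
# but are defined elsewhere in the repo. Below is a minimal sketch of their
# presumed behavior, inferred from the worked examples in the comments above;
# these are NOT the repo's actual definitions.
def rep_and_range_sketch(counts: torch.Tensor
                         ) -> t.Tuple[torch.Tensor, torch.Tensor]:
    """For counts = [4, 2, 5], return
    rep_ids = [0 0 0 0 1 1 2 2 2 2 2] and ranges = [0 1 2 3 0 1 0 1 2 3 4]."""
    rep_ids = torch.repeat_interleave(
        torch.arange(counts.size(0), device=counts.device), counts)
    # per-segment arange: global arange minus the exclusive cumsum of counts
    offsets = torch.cumsum(counts, dim=0) - counts
    ranges = torch.arange(counts.sum(), device=counts.device) - offsets[rep_ids]
    return rep_ids, ranges


def pad_first_sketch(x: torch.Tensor) -> torch.Tensor:
    """Prepend a single zero, e.g. turning an inclusive cumulative sum
    (with its last entry dropped) into an exclusive one."""
    return torch.cat([torch.zeros(1, dtype=x.dtype, device=x.device), x])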
def pack_decoder(mol_array: torch.Tensor,
                 ms=MoleculeSpec.get_default()) -> t.Tuple[torch.Tensor, ...]:
    """Pack and expand the information in mol_array in order to feed it into
    the neural network

    Args:
        mol_array (torch.Tensor): Input molecule array,
            size [batch_size, max_num_steps, 5], type: `torch.long`,
            where 5 = atom_type + begin_ids + end_ids + bond_type + is_scaffold
        ms (mol_spec.MoleculeSpec)

    Returns:
        atom_types (torch.Tensor): Atom type information packed into
            a single vector, type: torch.long
        is_scaffold (torch.Tensor): Whether the corresponding atom is
            contained in the scaffold, type: torch.long
        bond_info (torch.Tensor): Bond type information packed into
            a single matrix, type: torch.long, shape: [-1, 3],
            where 3 = begin_ids + end_ids + bond_type
        actions (torch.Tensor): The action to carry out at each step,
            type: torch.long, shape: [-1, 5], where
            5 = action_type + atom_type + bond_type + append_loc + connect_loc
        mol_ids, step_ids, block_ids (torch.Tensor): Index information,
            type: torch.long
    """
    # get device info
    device = mol_array.device

    # magic numbers: the column layout of mol_array
    I_ATOM_TYPE, I_BEGIN_IDS, I_END_IDS, I_BOND_TYPE, I_IS_SCAFFOLD = range(5)

    # The number of decoding steps required for the entire molecule
    # size: [batch_size, ]
    num_total_steps = mol_array[:, :, I_END_IDS].ge(0).long().sum(-1)

    # The number of steps required for the generation of the scaffold
    # size: [batch_size, ]
    num_scaffold_steps = mol_array[:, :, I_IS_SCAFFOLD].eq(1).long().sum(-1)

    # The number of steps required for the generation of side chains
    # size: [batch_size, ]
    # NOTE: The additional 1 step is the termination step
    num_steps = num_total_steps - num_scaffold_steps + 1

    # Get molecule id and step id for each (unexpanded) atom/node
    # Example:
    # if we have num_steps = [4, 2, 5] and batch_size = 3, then
    # molecule:  |     0     |  1  |      2      |
    # mol_ids  = [0  0  0  0  1  1  2  2  2  2  2]
    # step_ids = [0  1  2  3  0  1  0  1  2  3  4]
    mol_ids, step_ids = rep_and_range(num_steps)

    # Expanding molecule
    # Example:
    # if we have num_steps = [3, 2],
    # batch_size = 2,
    # step_ids = [0 1 2 0 1]
    # and num_scaffold_steps = [3, 2], then
    # molecule:  |  0  |  1  |
    # num_steps: |  3  |  2  |
    # steps:     |  0  |  1  |  2  |  0  |  1  |
    # rep_ids_rep = [0 0 0 1 1 1 1 2 2 2 2 2 3 3 4 4 4]
    # indexer     = [0 1 2 0 1 2 3 0 1 2 3 4 0 1 0 1 2]
    (rep_ids_rep,
     indexer) = rep_and_range(step_ids + num_scaffold_steps[mol_ids])

    # Expanding mol_ids
    # Example:
    # molecule: |  0  |  1  |
    # mol_ids:  |  0  |  0  |  0  |  1  |  1  |
    # rep_ids_rep = [0 0 0 1 1 1 1 2 2 2 2 2 3 3 4 4 4]
    # mol_ids_rep = [0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1]
    mol_ids_rep = mol_ids[rep_ids_rep]

    # Expanding and packing mol_array
    # Example:
    # if we have
    # mol_array = a1 a2 a3 a4 a5
    #             b1 b2 b3 -- --
    # where: a1 = [atom_type, begin_ids, end_ids, bond_type, is_scaffold]
    #        -- = [-1, -1, -1, -1, -1]
    # then
    # indexer     = [0  1  2  0  1  2  3  0  1  2  3  4  0  1  0  1  2 ]
    # mol_ids_rep = [0  0  0  0  0  0  0  0  0  0  0  0  1  1  1  1  1 ]
    # mol_array_packed
    #             = [a1 a2 a3 a1 a2 a3 a4 a1 a2 a3 a4 a5 b1 b2 b1 b2 b3]
    # shape: [-1, 5]
    mol_array_packed = mol_array[mol_ids_rep, indexer, :]

    # Get the expanded atom type
    atom_types, is_scaffold = (mol_array_packed[:, I_ATOM_TYPE],
                               mol_array_packed[:, I_IS_SCAFFOLD])

    # Get the number of atoms at each step: |V_i|
    # Example:
    # molecule:  |      0      |   1   |
    # steps:     |  0 |  1 | 2 | 0 | 1 |
    # num_atoms = [ 3 ,  4 , 5 , 2 , 3 ]
    # Note: connect actions should be filtered out first
    # binary vector with int values
    is_connect = atom_types.ge(0).long()
    is_connect_index = is_connect.nonzero().squeeze()
    num_atoms = torch_scatter.scatter_add(is_connect,
                                          dim=0,
                                          index=rep_ids_rep)
    atom_types = atom_types[is_connect_index]
    is_scaffold = is_scaffold[is_connect_index]

    # Get last_append_mask
    OLD_ATOM, NEW_APPEND, NEW_CONNECT = range(3)
    # Locations of the latest appended atoms
    last_append_loc = torch.cumsum(num_atoms, dim=0) - 1
    # Initialize last_append_mask as zeros
    last_append_mask = torch.full_like(is_scaffold, OLD_ATOM)
    last_append_mask[last_append_loc] = torch.where(
        is_scaffold[last_append_loc].eq(1),
        torch.full_like(last_append_loc, OLD_ATOM),
        torch.where(num_atoms.gt(pad_first(num_atoms[:-1])),
                    torch.full_like(last_append_loc, NEW_APPEND),
                    torch.full_like(last_append_loc, NEW_CONNECT)))

    # block_ids, essentially equal to the filtered rep_ids_rep
    block_ids, atom_ids = rep_and_range(num_atoms)

    # Get (packed) bond information
    # size: [-1, 3], where 3 = begin_ids, end_ids, bond_type
    bond_info = mol_array_packed[:, [I_BEGIN_IDS, I_END_IDS, I_BOND_TYPE]]

    # adjust begin_ids and end_ids for each bond
    num_atoms_cumsum = pad_first(torch.cumsum(num_atoms, dim=0)[:-1])
    I_BOND_BEGIN_IDS = 0
    _filter = bond_info[:, I_BOND_BEGIN_IDS].ge(0).nonzero().squeeze()
    _shift = num_atoms_cumsum[rep_ids_rep]
    _shift = torch.stack([_shift, ] * 2 + [torch.zeros_like(_shift)], dim=1)
    bond_info = (bond_info + _shift)[_filter, :]

    # symmetrize bond_info
    bond_info = torch.cat([bond_info, bond_info[:, [1, 0, 2]]])

    # labels for artificial bonds
    (I_BOND_REMOTE_2,
     I_BOND_REMOTE_3,
     I_BOND_ATOM_SELF) = range(ms.num_bond_types, ms.num_bond_types + 3)

    # artificial bond type: remote connections (2- and 3-hop neighbors)
    indices = bond_info[:, :2].cpu().numpy()
    size = atom_types.size(0)
    d_indices_2, d_indices_3 = get_remote_connection(indices, size)
    # pylint: disable=not-callable
    d_indices_2, d_indices_3 = (torch.tensor(d_indices_2,
                                             dtype=torch.long,
                                             device=device),
                                torch.tensor(d_indices_3,
                                             dtype=torch.long,
                                             device=device))
    bond_type = torch.full([d_indices_2.size(0), 1],
                           I_BOND_REMOTE_2,
                           dtype=torch.long,
                           device=device)
    bond_info_remote_2 = torch.cat([d_indices_2, bond_type], dim=-1)
    bond_type = torch.full([d_indices_3.size(0), 1],
                           I_BOND_REMOTE_3,
                           dtype=torch.long,
                           device=device)
    bond_info_remote_3 = torch.cat([d_indices_3, bond_type], dim=-1)
    bond_info = torch.cat([bond_info, bond_info_remote_2, bond_info_remote_3],
                          dim=0)

    # artificial bond type: self connection
    begin_ids = end_ids = torch.arange(atom_types.size(0),
                                       dtype=torch.long,
                                       device=atom_types.device)
    bond_type = torch.full_like(end_ids, I_BOND_ATOM_SELF)
    bond_info_self = torch.stack([begin_ids, end_ids, bond_type], dim=-1)
    bond_info = torch.cat([bond_info, bond_info_self], dim=0)

    # compile the action for each step, which contains:
    # 1. the type of action to carry out: append, connect or terminate
    # 2. the type of atom to append (0 for connect and termination actions)
    # 3. the type of bond to connect (0 for termination actions)
    # 4. the location to append (0 for connect and termination actions)
    # 5. the location to connect (0 for append and termination actions)

    # get the batch size
    batch_size = mol_array.size(0)
    # NOTE: the padding is placed on the same device as the input
    # (this line previously hard-coded 'cuda:0')
    padding = torch.full([batch_size, 1, 5],
                         -1,
                         dtype=torch.long,
                         device=device)
    actions = torch.cat([mol_array, padding], dim=1)
    actions = actions[mol_ids, step_ids + num_scaffold_steps[mol_ids], :]

    # 1. THE TYPE OF ACTION PERFORMED AT EACH STEP
    # 0 for append, 1 for connect, 2 for termination
    I_MASK, I_APPEND, I_CONNECT, I_END = 0, 0, 1, 2

    def _full(_x):
        """A helper function creating a constant vector with the same type
        and length as `actions`, filled with the given value"""
        return torch.full([actions.size(0), ],
                          _x,
                          dtype=torch.long,
                          device=mol_array.device)

    action_type = torch.where(
        # if the atom type is defined for step i
        actions[:, I_ATOM_TYPE].ge(I_MASK),
        # then the action type is set to 'append'
        _full(I_APPEND),
        torch.where(
            # if the bond type is defined for step i
            actions[:, I_BOND_TYPE].ge(I_MASK),
            # then the action is set to 'connect'
            _full(I_CONNECT),
            # else 'terminate'
            _full(I_END)))

    # 2. THE TYPE OF ATOM ADDED AT EACH 'APPEND' STEP
    action_atom_type = torch.where(actions[:, I_ATOM_TYPE].ge(I_MASK),
                                   actions[:, I_ATOM_TYPE],
                                   _full(I_MASK))

    # 3. THE BOND TYPE AT EACH STEP
    action_bond_type = torch.where(actions[:, I_BOND_TYPE].ge(0),
                                   actions[:, I_BOND_TYPE],
                                   _full(I_MASK))

    # 4. THE LOCATION TO APPEND AT EACH STEP
    append_loc = torch.where(action_type.eq(I_APPEND),
                             actions[:, I_BEGIN_IDS] + num_atoms_cumsum,
                             _full(I_MASK))

    # 5. THE LOCATION TO CONNECT AT EACH STEP
    connect_loc = torch.where(action_type.eq(I_CONNECT),
                              actions[:, I_END_IDS] + num_atoms_cumsum,
                              _full(I_MASK))

    # Stack everything together
    # size: [-1, 5]
    # 5 = action_type, atom_type, bond_type, append_loc, connect_loc
    actions = torch.stack([action_type,
                           action_atom_type,
                           action_bond_type,
                           append_loc,
                           connect_loc], dim=-1)

    return (
        # 1. structure information:
        # atom (node) type and bond (edge) information
        atom_types, is_scaffold, bond_info, last_append_mask,
        # 2. action to carry out at each step
        actions,
        # 3. indices
        mol_ids, step_ids, block_ids, atom_ids)
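# A minimal, self-contained illustration (with made-up values) of the
# per-block atom count used in both packing functions: `scatter_add` sums a
# binary "defines a new atom" indicator over block indices.
def _demo_num_atoms():
    is_connect = torch.tensor([1, 1, 1, 0, 1, 1, 1])   # 0 marks a connect step
    rep_ids_rep = torch.tensor([0, 0, 0, 1, 1, 1, 1])  # block id of each entry
    num_atoms = torch_scatter.scatter_add(is_connect, dim=0, index=rep_ids_rep)
    assert num_atoms.tolist() == [3, 3]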
from itertools import chain

import numpy as np
import networkx as nx
from rdkit import Chem
from rdkit.Chem import rdmolops
from scipy import sparse

from mol_spec import MoleculeSpec

__all__ = [
    # "smiles_to_nx_graph",
    "WeaveMol",
]

ms = MoleculeSpec.get_default()


class WeaveMol(object):
    def __init__(self, smiles: str):
        self.smiles = smiles
        self.mol = Chem.MolFromSmiles(smiles)
        self.ms = ms
        self.num_atom_types = ms.num_atom_types
        self.num_bond_types = ms.num_bond_types
        self.num_atoms = self.mol.GetNumAtoms()
        self.original_atoms = list(range(self.num_atoms))
        self.num_original_bonds = self.mol.GetNumBonds()
        self._atom_types = None
        self._original_bond_info = None
        self._original_bond_info_np = None
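# A hedged usage sketch for WeaveMol; the SMILES string is an arbitrary
# illustration value, not taken from the dataset.
def _demo_weave_mol():
    wmol = WeaveMol('c1ccccc1O')  # phenol
    assert wmol.num_atoms == wmol.mol.GetNumAtoms()
    print(wmol.num_atoms, wmol.num_original_bonds)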
def get_data_loader(scaffold_network_loc,
                    molecule_smiles_loc,
                    exclude_ids_loc,
                    split_type,
                    batch_size,
                    batch_size_test,
                    num_iterations,
                    num_workers,
                    k,
                    p,
                    ms=MoleculeSpec.get_default()):
    """Helper function for getting the train and test dataloaders

    Args:
        scaffold_network_loc (str): The location of the network file
        molecule_smiles_loc (str): The location of the file containing
            molecular SMILES
        exclude_ids_loc (str): File storing the indices of the
            molecules/scaffolds that should be excluded
        split_type (str): The type of split, should be 'scaffold'
            or 'molecule'
        batch_size (int): Batch size for training
        batch_size_test (int): Batch size for test
        num_iterations (int): The number of iterations to train
        num_workers (int): The number of workers used for data loading
        k (int): The number of importance samples
        p (float): Degree of uncertainty during route sampling
        ms (MoleculeSpec, optional)

    Returns:
        t.Tuple[DataLoader, DataLoader]:
            The loaders for training and test data
    """
    # Get dataset
    db = ScaffoldMolDataset(scaffold_network_loc, molecule_smiles_loc)

    # The scaffold and molecule sub-batches must be of equal size
    assert batch_size % 2 == 0 and batch_size_test % 2 == 0

    # Get samplers
    sampler_train = ScaffoldMolSampler(
        dataset=db,
        batch_size=(batch_size // 2, batch_size // 2),
        num_iterations=num_iterations,
        exclude_ids_loc=exclude_ids_loc,
        training=True,
        split_type=split_type)
    sampler_test = ScaffoldMolSampler(
        dataset=db,
        batch_size=(batch_size_test // 2, batch_size_test // 2),
        num_iterations=num_iterations,
        exclude_ids_loc=exclude_ids_loc,
        training=False,
        split_type=split_type)

    # Get DataLoaders; the test loader always loads in the main process
    loader_train = DataLoader(dataset=db,
                              sampler=sampler_train,
                              k=k,
                              p=p,
                              num_workers=num_workers,
                              ms=ms)
    loader_test = DataLoader(dataset=db,
                             sampler=sampler_test,
                             k=k,
                             p=p,
                             num_workers=0,
                             ms=ms)
    return loader_train, loader_test
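# A hedged usage sketch for get_data_loader. The file locations mirror the
# defaults documented in engine() below; the remaining values are the
# documented training defaults.
def _demo_get_data_loader():
    loader_train, loader_test = get_data_loader(
        scaffold_network_loc='data_utils/scaffolds_molecules.pkl.gz',
        molecule_smiles_loc='data_utils/molecules.smi',
        exclude_ids_loc='ckpt/ckpt-default/exclude_ids.txt',
        split_type='molecule',
        batch_size=128,
        batch_size_test=256,
        num_iterations=50000,
        num_workers=2,
        k=5,
        p=0.5)
    batch = next(iter(loader_train))  # one packed training batch
    return batch, loader_test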
def engine(ckpt_loc='ckpt/ckpt-default',
           molecule_loc='data_utils/molecules.smi',
           network_loc='data_utils/scaffolds_molecules.pkl.gz',
           exclude_ids_loc='ckpt/ckpt-default/exclude_ids.txt',
           full=False,
           split_by='molecule',
           training_only=False,
           num_workers=2,
           num_atom_embedding=16,
           causal_hidden_sizes=(32, 64),
           num_bn_features=96,
           num_k_features=24,
           num_layers=20,
           num_output_features=256,
           efficient=False,
           ms=MoleculeSpec.get_default(),
           activation='elu',
           lr=1e-3,
           decay=0.01,
           decay_step=100,
           min_lr=5e-5,
           summary_step=200,
           clip_grad=3.0,
           batch_size=128,
           batch_size_test=256,
           num_iterations=50000,
           k=5,
           p=0.5,
           gpu_ids=(0, 1, 2, 3)):
    """Engine for training the scaffold-based VAE

    Args:
        ckpt_loc (str): Location to store model checkpoints
        molecule_loc (str): Location of molecule SMILES strings
        network_loc (str): Location of the bipartite network
        exclude_ids_loc (str): The location storing the ids to be excluded
            from the training set
        full (bool): Whether to use the full dataset for training,
            default to False
        split_by (str): Whether to split by 'scaffold' or 'molecule'
        training_only (bool): Record only the training loss,
            default to False
        num_workers (int): Number of workers used during data loading,
            default to 2
        num_atom_embedding (int): The size of the initial node embedding
        causal_hidden_sizes (tuple[int] or list[int]): The sizes of the
            hidden layers in the causal weave blocks
        num_bn_features (int): The number of features used in the
            bottleneck layer of each dense layer
        num_k_features (int): The growth rate of the dense net
        num_layers (int): The number of densenet layers
        num_output_features (int): The number of output features
            of the densenet
        efficient (bool): Whether to use the memory-efficient
            implementation of densenet
        ms (mol_spec.MoleculeSpec)
        activation (str): The activation function used, default to 'elu'
        lr (float): (Initial) learning rate
        decay (float): The rate of learning rate decay
        decay_step (int): The interval of each learning rate decay
        min_lr (float): The minimum learning rate
        summary_step (int): Interval of summary
        clip_grad (float): Gradient clipping
        batch_size (int): The batch size for training
        batch_size_test (int): The batch size for testing
        num_iterations (int): The total number of training iterations
        k (int): The number of importance samples
        p (float): The degree of stochasticity of importance sampling:
            0.0 for fully stochastic decoding,
            1.0 for fully deterministic decoding
        gpu_ids (tuple[int] or list[int]): Which GPUs are used for training
    """
    # ANCHOR Check whether to continue training
    is_continuous = _check_continuous(ckpt_loc)

    # ANCHOR Create iterators for the training and test datasets
    loader_train, loader_test = _get_loader(network_loc=network_loc,
                                            molecule_loc=molecule_loc,
                                            exclude_ids_loc=exclude_ids_loc,
                                            split_by=split_by,
                                            batch_size=batch_size,
                                            batch_size_test=batch_size_test,
                                            num_iterations=num_iterations,
                                            num_workers=num_workers,
                                            full=full,
                                            training_only=training_only,
                                            k=k,
                                            p=p,
                                            ms=ms)
    iter_train = iter(loader_train)
    iter_test = iter(loader_test) if loader_test is not None else None

    # ANCHOR Initialize the model with random params
    mdl = _init_mdl(num_atom_embedding=num_atom_embedding,
                    causal_hidden_sizes=causal_hidden_sizes,
                    num_bn_features=num_bn_features,
                    num_k_features=num_k_features,
                    num_layers=num_layers,
                    num_output_features=num_output_features,
                    efficient=efficient,
                    activation=activation,
                    gpu_ids=gpu_ids)

    # ANCHOR Initialize the optimizer and scheduler
    optimizer = optim.Adam(mdl.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, decay_step, 1.0 - decay)

    # ANCHOR Load previously stored states
    if is_continuous:
        # restore states
        mdl, optimizer, scheduler, t0, global_counter = _restore(
            mdl=mdl,
            optimizer=optimizer,
            scheduler=scheduler,
            ckpt_loc=ckpt_loc)
    else:
        t0 = time.time()
        global_counter = 0

    device = torch.device(f'cuda:{gpu_ids[0]}')
    with open(os.path.join(ckpt_loc, 'log.out'),
              mode='a' if is_continuous else 'w') as f:
        if not is_continuous:
            f.write('global_step\ttime(min)\tloss\tlr\n')
        try:
            while True:
                global_counter += 1  # Update global counter
                # Perform one step of training
                loss = _train_step(mdl=mdl,
                                   optimizer=optimizer,
                                   scheduler=scheduler,
                                   min_lr=min_lr,
                                   clip_grad=clip_grad,
                                   device=device,
                                   iter_train=iter_train)
                if global_counter % summary_step == 0:
                    if not training_only:
                        try:
                            loss = _test_step(mdl, device, iter_test)
                        except StopIteration:
                            iter_test = iter(loader_test)
                            loss = _test_step(mdl, device, iter_test)
                    loss = loss.item()
                    # Get learning rate
                    current_lr = [params_group['lr']
                                  for params_group
                                  in optimizer.param_groups][0]
                    # Save status
                    message_str = _save(mdl=mdl,
                                        optimizer=optimizer,
                                        scheduler=scheduler,
                                        global_counter=global_counter,
                                        t0=t0,
                                        loss=loss,
                                        current_lr=current_lr,
                                        ckpt_loc=ckpt_loc)
                    f.write(message_str)
                    f.flush()
        except StopIteration:
            # the training iterator is exhausted: run one final summary step
            if not training_only:
                try:
                    loss = _test_step(mdl, device, iter_test)
                except StopIteration:
                    iter_test = iter(loader_test)
                    loss = _test_step(mdl, device, iter_test)
            loss = loss.item()
            # Get learning rate
            current_lr = [params_group['lr']
                          for params_group in optimizer.param_groups][0]
            # Save status
            message_str = _save(mdl=mdl,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                global_counter=global_counter,
                                t0=t0,
                                loss=loss,
                                current_lr=current_lr,
                                ckpt_loc=ckpt_loc)
            f.write(message_str)
            f.flush()
        f.write('Training finished')
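# A hedged invocation sketch for engine(). All values shown are the documented
# defaults, except num_iterations and gpu_ids, which are shortened here for a
# quick smoke run (assumes at least one CUDA device is available).
def _demo_engine():
    engine(ckpt_loc='ckpt/ckpt-default',
           num_iterations=1000,  # shortened from the 50000-step default
           batch_size=128,
           k=5,
           p=0.5,
           gpu_ids=(0,))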
def __init__(self,
             num_atom_embedding: int,
             causal_hidden_sizes: t.Tuple[int, ...],
             num_bn_features: int,
             num_k_features: int,
             num_layers: int,
             num_output_features: int,
             efficient: bool = False,
             ms: MoleculeSpec = MoleculeSpec.get_default(),
             activation: str = 'elu',
             conditional: bool = False,
             num_cond_features: t.Optional[int] = None,
             activation_cond: t.Optional[str] = None):
    """The constructor

    Args:
        num_atom_embedding (int): The size of the initial node embedding
        causal_hidden_sizes (tuple[int]): The sizes of the hidden layers
            in the causal weave blocks
        num_bn_features (int): The number of features used in the
            bottleneck layer of each dense layer
        num_k_features (int): The growth rate of the dense net
        num_layers (int): The number of densenet layers
        num_output_features (int): The number of output features
            of the densenet
        efficient (bool): Whether to use the memory-efficient
            implementation of densenet
        ms (mol_spec.MoleculeSpec)
        activation (str): The activation function used, default to 'elu'
        conditional (bool): Whether to include conditional input,
            default to False
        num_cond_features (int or None): The size of the conditional input,
            should be None if self.conditional is False
        activation_cond (str or None): The activation function used for the
            conditional input, should be None if self.conditional is False
    """
    super(DeepScaffold, self).__init__()
    self.num_atom_embedding = num_atom_embedding
    self.causal_hidden_sizes = causal_hidden_sizes
    self.num_bn_features = num_bn_features
    self.num_k_features = num_k_features
    self.num_layers = num_layers
    self.num_output_features = num_output_features
    self.efficient = efficient
    # pylint: disable=invalid-name
    self.ms = ms
    # 3 = 2 * remote connection + self connection
    self._num_bond_types = self.ms.num_bond_types + 3
    self._num_atom_types = self.ms.num_atom_types
    self.activation = activation
    self.conditional = conditional
    self.num_cond_features = num_cond_features
    self.activation_cond = activation_cond

    # embedding layer for atom types and bond types
    # 3 = is_scaffold + new_append + new_connect
    self.atom_embedding = nn.Embedding(
        self._num_atom_types + self.ms.num_atom_types * 3,
        self.num_atom_embedding)

    # convolution layer
    self.mol_conv = DenseNet(self.num_atom_embedding,
                             self._num_bond_types,
                             self.causal_hidden_sizes,
                             self.num_bn_features,
                             self.num_k_features,
                             self.num_layers,
                             self.num_output_features,
                             self.efficient,
                             self.activation,
                             self.conditional,
                             self.num_cond_features,
                             self.activation_cond)

    # Pooling layer
    self.avg_pool = AvgPooling(self.num_output_features, self.activation)

    # output layers
    self.end = BNReLULinear(self.num_output_features, 1, self.activation)
    self.append_connect = BNReLULinear(
        self.num_output_features * 2,
        ms.num_atom_types * ms.num_bond_types + ms.num_bond_types,
        self.activation)
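# A hedged instantiation sketch for DeepScaffold; the hyperparameters mirror
# the defaults documented in engine() above.
def _demo_deep_scaffold():
    mdl = DeepScaffold(num_atom_embedding=16,
                       causal_hidden_sizes=(32, 64),
                       num_bn_features=96,
                       num_k_features=24,
                       num_layers=20,
                       num_output_features=256,
                       efficient=False,
                       activation='elu')
    return mdl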
def get_mol_from_array(mol_array, sanitize=True, ms=MoleculeSpec.get_default()):
    """Convert a molecule array into a list of Chem.Mol objects

    Args:
        mol_array (np.ndarray): The array representation of molecules,
            dtype: int, shape: [num_samples, num_steps, 5]
        sanitize (bool): Whether to sanitize the output molecules,
            default to True
        ms (mol_spec.MoleculeSpec)

    Returns:
        list[Chem.Mol]: mol_list - The list of output molecules
    """
    # shape: num_samples, num_steps
    is_scaffold = mol_array[:, :, -1]
    # shape: num_samples, num_steps, 4
    mol_array = mol_array[:, :, :-1]

    # get shape information
    num_samples, max_num_steps, _ = mol_array.shape

    # initialize the list of output molecules
    mol_list = []

    # loop over molecules
    for mol_id in range(num_samples):
        try:
            mol = Chem.RWMol(Chem.Mol())  # initialize molecule
            atom_list = []       # List of all created atoms
            scaffold_atoms = []  # List of all scaffold atoms
            aromatic_atoms = []  # List of all aromatic atoms
            n_atoms = []         # The indices of all nitrogen atoms
            for step_id in range(max_num_steps):
                (atom_type,
                 begin_ids,
                 end_ids,
                 bond_type) = mol_array[mol_id, step_id, :].tolist()
                if end_ids == -1:
                    # if the action is to terminate
                    break
                elif begin_ids == -1:
                    # if the action is to initialize
                    new_atom = ms.index_to_atom(atom_type)
                    mol.AddAtom(new_atom)
                    atom_list.append(new_atom)
                    if new_atom.GetSymbol() == 'N':
                        n_atoms.append(end_ids)
                elif atom_type == -1:
                    # if the action is to connect
                    ms.index_to_bond(mol, begin_ids, end_ids, bond_type)
                else:
                    # if the action is to append a new atom
                    new_atom = ms.index_to_atom(atom_type)
                    mol.AddAtom(new_atom)
                    ms.index_to_bond(mol, begin_ids, end_ids, bond_type)
                    # Record atom
                    atom_list.append(new_atom)
                    if is_scaffold[mol_id, step_id]:
                        # Both ends are scaffold atoms
                        scaffold_atoms.append(end_ids)
                        scaffold_atoms.append(begin_ids)
                    if bond_type == ms.bond_orders.index(
                            Chem.BondType.AROMATIC):
                        aromatic_atoms.append(begin_ids)
                        aromatic_atoms.append(end_ids)
                    if new_atom.GetSymbol() == 'N':
                        n_atoms.append(end_ids)

            # aromatic nitrogen atoms in the scaffold need their explicit
            # hydrogen count or formal charge adjusted when a side-chain
            # atom has been attached to them
            special_atoms = (set(scaffold_atoms) &
                             set(aromatic_atoms) &
                             set(n_atoms))
            scaffold_atoms = set(scaffold_atoms)
            for atom_id in special_atoms:
                neighbors = (
                    mol_array[mol_id,
                              mol_array[mol_id, :, 2] == atom_id,
                              1].tolist() +
                    mol_array[mol_id,
                              mol_array[mol_id, :, 1] == atom_id,
                              2].tolist())
                neighbors = set(neighbors) - {-1}
                if neighbors - scaffold_atoms:
                    atom_i = mol.GetAtomWithIdx(atom_id)
                    if atom_i.GetNumExplicitHs() > 0:
                        num_explicit_hs = atom_i.GetNumExplicitHs() - 1
                        atom_i.SetNumExplicitHs(num_explicit_hs)
                    else:
                        num_formal_charge = atom_i.GetFormalCharge() + 1
                        atom_i.SetFormalCharge(num_formal_charge)
            if sanitize:
                mol = mol.GetMol()
                Chem.SanitizeMol(mol)
        except (ValueError, RuntimeError):
            mol = None
        mol_list.append(mol)

    return mol_list
def get_array_from_mol(mol,
                       scaffold_nodes,
                       nh_nodes,
                       np_nodes,
                       k,
                       p,
                       ms=MoleculeSpec.get_default()):
    """Represent the molecule using `np.ndarray`

    Args:
        mol (Chem.Mol): The input molecule
        scaffold_nodes (Iterable): The location of the scaffold,
            represented as `list`/`np.ndarray`
        nh_nodes (Iterable): Nodes whose explicit hydrogen count
            should be incremented
        np_nodes (Iterable): Nodes whose formal charge
            should be decremented
        k (int): The number of importance samples
        p (float): Degree of uncertainty during route sampling,
            should be in (0, 1)
        ms (mol_spec.MoleculeSpec)

    Returns:
        mol_array (np.ndarray): The numpy representation of the molecule,
            dtype - np.int32, shape - [k, num_bonds + 1, 5]
        logp (np.ndarray): The log-likelihood of each route,
            dtype - np.float32, shape - [k, ]
    """
    atom_types, bond_info = [], []
    num_bonds = mol.GetNumBonds()

    # sample routes
    scaffold_nodes = np.array(list(scaffold_nodes), dtype=np.int32)
    route_list, step_ids_list, logp = \
        _sample_ordering(mol=mol, scaffold_nodes=scaffold_nodes, k=k, p=p)

    for atom_id, atom in enumerate(mol.GetAtoms()):
        if atom_id in nh_nodes:
            atom.SetNumExplicitHs(atom.GetNumExplicitHs() + 1)
        if atom_id in np_nodes:
            atom.SetFormalCharge(atom.GetFormalCharge() - 1)
        atom_types.append(ms.get_atom_type(atom))

    for bond in mol.GetBonds():
        bond_info.append([bond.GetBeginAtomIdx(),
                          bond.GetEndAtomIdx(),
                          ms.get_bond_type(bond)])

    # shape:
    # atom_types: num_atoms
    # bond_info: num_bonds x 3
    atom_types = np.array(atom_types, dtype=np.int32)
    bond_info = np.array(bond_info, dtype=np.int32)

    # initialize packed molecule array data
    mol_array = []

    for sample_id in range(k):
        # get the route and step_ids for the i-th sample
        route_i = route_list[sample_id, :]
        step_ids_i = step_ids_list[sample_id, :]

        # reorder atom types and bond info
        # note: bond_info [start_ids, end_ids, bond_type]
        atom_types_i, bond_info_i, is_append = \
            _reorder(atom_types=atom_types,
                     bond_info=bond_info,
                     route=route_i,
                     step_ids=step_ids_i)

        # atom type added at each step
        # -1 if the current step is a connect step
        atom_types_added = np.full([num_bonds, ], -1, dtype=np.int32)
        atom_types_added[is_append] = \
            atom_types_i[bond_info_i[:, 1]][is_append]

        # pack into mol_array_i
        # size: num_bonds x 4
        # note: [atom_types_added, start_ids, end_ids, bond_type]
        mol_array_i = np.concatenate([atom_types_added[:, np.newaxis],
                                      bond_info_i], axis=-1)

        # add the initialization step
        init_step = np.array([[atom_types_i[0], -1, 0, -1]], dtype=np.int32)

        # concat into mol_array_i
        # size: (num_bonds + 1) x 4
        mol_array_i = np.concatenate([init_step, mol_array_i], axis=0)

        # Mark up scaffold bonds
        is_scaffold = np.logical_and(mol_array_i[:, 1] < len(scaffold_nodes),
                                     mol_array_i[:, 2] < len(scaffold_nodes))
        is_scaffold = is_scaffold.astype(np.int32)

        # Concatenate
        # size: (num_bonds + 1) x 5
        mol_array_i = np.concatenate((mol_array_i,
                                      is_scaffold[:, np.newaxis]), axis=-1)

        mol_array.append(mol_array_i)

    # shape: k x (num_bonds + 1) x 5
    mol_array = np.stack(mol_array, axis=0)

    # Output size:
    # mol_array: k x (num_bonds + 1) x 5
    # logp: k
    return mol_array, logp
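# A hedged round-trip sketch: encode a molecule into its array representation
# and decode it back. The SMILES string and scaffold indices are hypothetical
# illustration values (the first six atoms of 'c1ccccc1CCO' form the ring).
def _demo_round_trip():
    mol = Chem.MolFromSmiles('c1ccccc1CCO')
    scaffold_nodes = range(6)  # assume the benzene ring is the scaffold
    mol_array, logp = get_array_from_mol(mol, scaffold_nodes,
                                         nh_nodes=[], np_nodes=[],
                                         k=2, p=0.5)
    decoded = get_mol_from_array(mol_array, sanitize=True)
    print(logp, [Chem.MolToSmiles(m) for m in decoded if m is not None])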