def build_model(ckpt_loc: str) -> DeepScaffold:
    """
    Build a model from a checkpoint directory and prepare it for inference.

    Args:
        ckpt_loc (str): Directory containing `config.json` (overrides for
            the default hyper-parameters below) and `mdl.ckpt` (the saved
            state dict)

    Returns:
        DeepScaffold: The model with the checkpoint weights loaded,
            unwrapped from `nn.DataParallel`, moved to GPU 0 and switched
            to eval mode
    """
    # Configure default parameters
    config = {
        "num_atom_embedding": 16,
        "causal_hidden_sizes": (32, 64),
        "num_bn_features": 96,
        "num_k_features": 24,
        "num_layers": 20,
        "num_output_features": 256,
        "efficient": False,
        "activation": 'elu',
    }

    # Local configuration file (JSON is UTF-8 by specification, so do not
    # depend on the platform's default encoding)
    # pylint: disable=invalid-name
    with open(os.path.join(ckpt_loc, 'config.json'), encoding='utf-8') as f:
        config_update = json.load(f)

    # Update default configuration with config. Only keys that already
    # exist in the defaults are accepted; unknown keys are ignored so that
    # they are never forwarded to the DeepScaffold constructor.
    config.update({
        key: value
        for key, value in config_update.items()
        if key in config
    })

    # Build model
    mdl = DeepScaffold(**config)

    # Load checkpoint. The state dict is loaded into a DataParallel wrapper
    # (presumably because the checkpoint was saved from a wrapped model and
    # its keys carry the "module." prefix -- confirm against the training
    # code).
    mdl = nn.DataParallel(mdl)
    # map_location='cpu' keeps loading independent of the GPU index the
    # checkpoint was saved from; the model is moved to the GPU afterwards.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(os.path.join(ckpt_loc, 'mdl.ckpt'),
                            map_location='cpu')
    mdl.load_state_dict(state_dict)

    # Unwrap from nn.DataParallel, move to GPU
    return mdl.module.cuda(0).eval()
def _init_mdl(num_atom_embedding,
              causal_hidden_sizes,
              num_bn_features,
              num_k_features,
              num_layers,
              num_output_features,
              efficient,
              activation,
              gpu_ids):
    """Helper function for initializing model

    Args:
        num_atom_embedding (int):
            The size of the initial node embedding
        causal_hidden_sizes (tuple[int] or list[int]):
            The size of hidden layers in causal weave blocks
        num_bn_features (int):
            The number of features used in bottleneck layers in each dense
            layer
        num_k_features (int):
            The growth rate of dense net
        num_layers (int):
            The number of densenet layers
        num_output_features (int):
            The number of output features for the densenet
        efficient (bool):
            Whether to use the memory efficient implementation of densenet
        activation:
            Activation setting forwarded unchanged to `DeepScaffold`
            (`build_model` in this file defaults it to 'elu')
        gpu_ids:
            Passed to `nn.DataParallel` as `device_ids`

    Returns:
        nn.DataParallel: The model initialized, wrapped and moved to GPU
    """
    # Create empty model with config
    configs = {
        'num_atom_embedding': num_atom_embedding,
        'causal_hidden_sizes': causal_hidden_sizes,
        'num_bn_features': num_bn_features,
        'num_k_features': num_k_features,
        'num_layers': num_layers,
        'num_output_features': num_output_features,
        'efficient': efficient,
        'activation': activation,
    }
    mdl = DeepScaffold(**configs)

    # Weight initializer, applied to every sub-module by `mdl.apply`
    def init_weights(m):
        if isinstance(m, nn.Linear):
            # Use the in-place initializer; the un-suffixed
            # `nn.init.xavier_normal` is deprecated.
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm1d):
            m.weight.data.fill_(1.0)

    mdl.apply(init_weights)

    # Wrap into nn.DataParallel and move to gpu
    mdl = nn.DataParallel(mdl, device_ids=gpu_ids)
    mdl.cuda()
    return mdl
def sample(
        mdl: DeepScaffold,
        scaffold_smi: str,
        num_samples: int) -> t.Tuple[t.List[t.Union[str, None]], float, float]:
    """
    Generate `num_samples` samples from the model `mdl` based on a given
    scaffold with SMILES `scaffold_smi`.

    Args:
        mdl (DeepScaffold):
            The scaffold-based molecule generative model
        scaffold_smi (str):
            The SMILES string of the given scaffold
        num_samples (int):
            The number of samples to generate

    Returns:
        t.Tuple[t.List[t.Union[str, None]], float, float]:
            A tuple `(smiles_list, percent_valid, percent_unique)`:
            - `smiles_list`: the generated molecules as SMILES; molecules
              that do not satisfy the validity requirements are returned
              as `None`
            - `percent_valid`: fraction (0.0 - 1.0) of entries that are
              not `None`; 0.0 when no samples were produced
            - `percent_unique`: number of distinct valid SMILES divided by
              the number of valid SMILES; 0.0 when there are no valid
              samples

    Raises:
        ValueError: If `scaffold_smi` cannot be parsed into a molecule
    """
    # Raise the RDKit logger threshold to CRITICAL
    # pylint: disable=invalid-name
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)

    # Convert SMILES to molecule
    scaffold = Chem.MolFromSmiles(scaffold_smi)
    if scaffold is None:
        # Fail with a clear message instead of an AttributeError on None
        raise ValueError(
            'Invalid scaffold SMILES: {!r}'.format(scaffold_smi))

    # Convert molecule to numpy array
    # shape: 1, ..., 5
    scaffold_array: np.ndarray
    scaffold_array, _ = \
        get_array_from_mol(mol=scaffold,
                           scaffold_nodes=range(scaffold.GetNumHeavyAtoms()),
                           nh_nodes=[],
                           np_nodes=[],
                           k=1,
                           p=1.0)

    # Convert numpy array to torch tensor
    # shape: 1, ..., 5
    scaffold_tensor: torch.Tensor
    scaffold_tensor = torch.from_numpy(scaffold_array).long().cuda()

    # Generate
    with torch.no_grad():
        # Expand the first dimension
        # shape: num_samples, ..., 5
        scaffold_tensor = scaffold_tensor.expand(num_samples, -1, -1)
        # Generate samples
        # shape: [num_samples, -1, 5]
        mol_array = mdl.generate(scaffold_tensor)

    # Move to CPU
    mol_array = mol_array.detach().cpu().numpy()

    # Convert numpy array to Chem.Mol object
    mol_list: t.List[t.Union[None, Chem.Mol]]
    mol_list = get_mol_from_array(mol_array, sanitize=True)

    # Convert Chem.Mol object to SMILES
    def _to_smiles(_mol):
        if _mol is None:
            return None
        try:
            _smiles = Chem.MolToSmiles(_mol)
        except ValueError:
            # If the molecule can not be converted to SMILES, return None
            return None
        # If the output SMILES is None, return None
        if _smiles is None:
            return None
        # Make sure that the SMILES can be converted back to a molecule
        try:
            _round_trip = Chem.MolFromSmiles(_smiles)
        except ValueError:
            # If any error is encountered during the process, return None
            return None
        # If the output molecule object is None, return None
        if _round_trip is None:
            return None
        return _smiles

    smiles_list = [_to_smiles(_mol) for _mol in mol_list]

    # Get the validity statistics
    valid_smiles = [smi for smi in smiles_list if smi is not None]
    num_valid = len(valid_smiles)
    # Guard against an empty sample list (e.g. num_samples == 0)
    percent_valid = float(num_valid) / len(smiles_list) if smiles_list else 0.0

    # Get the uniqueness statistics. Count distinct valid SMILES directly
    # rather than subtracting one for `None`, which is wrong whenever every
    # sample is valid and `None` never appears in the list.
    num_unique = len(set(valid_smiles))
    percent_unique = float(num_unique) / num_valid if num_valid else 0.0

    return smiles_list, percent_valid, percent_unique