Example #1
def con_draw_mols(total_optim_dict, dir_name, sim=3):
    all_keys = list(total_optim_dict.keys())
    cnt = 0
    for key in all_keys:
        if len(total_optim_dict[key][sim][0]) > 0:
            gen_dir = './saved_images_{}/{}'.format(dir_name, cnt)
            os.makedirs(gen_dir, exist_ok=True)
            cnt += 1
            t_flag = 0

            #             for i in range(len(total_optim_dict[key][0][0])):
            #                 mol = Chem.MolFromSmiles(total_optim_dict[key][0][0][i])
            #                 if qed(mol) < 0.8: continue
            #                 filepath = os.path.join(gen_dir, 'sim_{}_imp_{}_qed_{}_{}.png'.format(total_optim_dict[key][0][2][i],total_optim_dict[key][0][1][i],qed(mol),total_optim_dict[key][0][0][i]))
            #                 img = Draw.MolToImage(mol)
            #                 img.save(filepath)
            for i in range(len(total_optim_dict[key][1][0])):
                mol = Chem.MolFromSmiles(total_optim_dict[key][1][0][i])
                if total_optim_dict[key][1][1][i] < 2: continue
                filepath = os.path.join(
                    gen_dir, 'sim_{}_imp_{}_qed_{}_{}.png'.format(
                        total_optim_dict[key][1][2][i],
                        total_optim_dict[key][1][1][i], qed(mol),
                        total_optim_dict[key][1][0][i]))
                img = Draw.MolToImage(mol)
                img.save(filepath)
                t_flag = 1
            for i in range(len(total_optim_dict[key][2][0])):
                mol = Chem.MolFromSmiles(total_optim_dict[key][2][0][i])
                if total_optim_dict[key][2][1][i] < 2: continue
                filepath = os.path.join(
                    gen_dir, 'sim_{}_imp_{}_qed_{}_{}.png'.format(
                        total_optim_dict[key][2][2][i],
                        total_optim_dict[key][2][1][i], qed(mol),
                        total_optim_dict[key][2][0][i]))
                img = Draw.MolToImage(mol)
                img.save(filepath)
                t_flag = 1
            for i in range(len(total_optim_dict[key][3][0])):
                mol = Chem.MolFromSmiles(total_optim_dict[key][3][0][i])
                if total_optim_dict[key][3][1][i] < 2: continue
                filepath = os.path.join(
                    gen_dir, 'sim_{}_imp_{}_qed_{}_{}.png'.format(
                        total_optim_dict[key][3][2][i],
                        total_optim_dict[key][3][1][i], qed(mol),
                        total_optim_dict[key][3][0][i]))
                img = Draw.MolToImage(mol)
                img.save(filepath)
                t_flag = 1
            if t_flag == 1:
                filepath = os.path.join(gen_dir, 'original_{}.png'.format(key))
                mol = Chem.MolFromSmiles(key)
                img = Draw.MolToImage(mol)
                img.save(filepath)
    return
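A minimal call sketch for the function above, assuming total_optim_dict maps an original SMILES to four (SMILES list, improvement list, similarity list) tuples at indices 0-3, as the indexing suggests; the sample data and directory name are hypothetical.

import os
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.QED import qed

# Hypothetical entry: each index holds ([smiles], [improvement], [similarity]).
total_optim_dict = {
    'c1ccccc1O': [
        ([], [], []),
        (['c1ccccc1OC'], [2.3], [0.55]),
        (['c1ccccc1OCC'], [2.8], [0.50]),
        (['c1ccccc1OC(C)C'], [3.1], [0.45]),
    ]
}
con_draw_mols(total_optim_dict, 'demo', sim=3)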
Example #2
    def evaluate(self, point: Any) -> float:
        """
        Evaluate a point.

        Args:
            point: point to evaluate.

        Returns:
            evaluation for the given point.
        """
        latent_point = torch.tensor([[point]])
        batch_latent = latent_point.repeat(1, self.batch, 1)

        smiles = self.generator.generate_smiles(batch_latent)

        qed_values = []
        for smile in smiles:
            try:
                qed_values.append(qed(Chem.MolFromSmiles(smile)))
            except Exception:
                qed_values.append(0)
                logger.warning("QED calculation failed.")

        if len(qed_values) > 0:
            return 1.0 - (sum(qed_values) / len(qed_values))
        else:
            return 1.0
Example #3
 def qed_evaluate(self, valid_smiles):
     qed_lst = []
     for i in valid_smiles:
         mol = Chem.MolFromSmiles(i)
         qed_score = qed(mol)
         qed_lst.append(qed_score)
     return qed_lst
Example #4
 def filtered_qed(self, valid_smiles):
     count = 0
     for i in list(valid_smiles):  # iterate over a copy so remove() does not skip items
         mol = Chem.MolFromSmiles(i)
         qed_score = qed(mol)
         if qed_score < 0.4:
             valid_smiles.remove(i)
             count = count + 1
     print("unavaliable QED mol:%i" % count)
     return valid_smiles
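For reference, a non-mutating sketch of the same QED filter; the 0.4 cut-off follows the method above, and the function name is illustrative.

from rdkit import Chem
from rdkit.Chem.QED import qed

def filter_by_qed(valid_smiles, threshold=0.4):
    # Keep SMILES whose QED meets the threshold; report how many were dropped.
    kept = [s for s in valid_smiles if qed(Chem.MolFromSmiles(s)) >= threshold]
    print("unavailable QED mol: %i" % (len(valid_smiles) - len(kept)))
    return kept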
Example #5
def reward_target_qed(mol, target, ratio=0.1, max=4):
    """
    Reward for a target QED value
    :param mol: rdkit mol object
    :param target: float
    :return: float (-inf, max]
    """
    x = qed(mol)
    reward = -1 * np.abs((x - target) / ratio) + max
    return reward
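For illustration: with the defaults ratio=0.1 and max=4, a molecule with QED 0.8 and target 0.9 gets reward -|0.8 - 0.9| / 0.1 + 4 = 3.0. A quick check (the SMILES below is an arbitrary choice):

import numpy as np
from rdkit import Chem
from rdkit.Chem.QED import qed

mol = Chem.MolFromSmiles('CC(=O)Oc1ccccc1C(=O)O')  # aspirin, QED roughly 0.55
print(reward_target_qed(mol, target=0.9))          # about -|0.55 - 0.9| / 0.1 + 4 = 0.5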
Example #6
def test_mol_score():
    mol_smiles = 'COC1=CC=C(C2=CC(C3=CC=CC=C3)=CC(C3=CC=C(Br)C=C3)=[O+]2)C=C1'
        #'CC(C)CCN1N=C(C(=O)NC2(CCCCCCC3=CC=CC=C3)CCCCC2)C=CC1=O'
        #'C2=CN=C(C=C2)CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC'  # 'CCCCCCCCCCNCCCC(CCC)CCCCCCCCCCCCCCCCC' #''CCCc1ccccc1C=CC=CCNNC(=O)CCc1ccc(OC)c(C)c1'
    print(len(mol_smiles))
    mol = Chem.MolFromSmiles(mol_smiles)

    plogp = penalized_logp(mol)
    q = qed(mol)
    print(mol_smiles)
    print(plogp, q)
Example #7
def qed_score(mol):
    """
    Quantitative Drug Likeness (QED)
    :param mol: input mol
    :return: score
    """
    try:
        score = qed(mol)
    except Exception:
        score = 0
    return score
Example #8
def compute_cost_qed(G, writefile="temp.txt"):
    qed_val = 2.0
    if guess_correct_molecules_from_graph(G, writefile):
        m1 = Chem.MolFromMol2File(writefile)
        if m1 is not None:
            qed_val = 1.0 - qed(m1)
        else:
            print "Error: None"
    else:
        print "Error: wrong molecule"

    return qed_val
Example #9
    def __call__(self, mol):
        """
        Returns the QED of a SMILES string or a RdKit molecule.
        """

        # Error handling.
        if type(mol) == rdkit.Chem.rdchem.Mol:
            pass
        elif type(mol) == str:
            mol = Chem.MolFromSmiles(mol, sanitize=False)
            if mol is None:
                raise ValueError("Invalid SMILES string.")
        else:
            raise TypeError("Input must be from {str, rdkit.Chem.rdchem.Mol}")

        return qed(mol)
Example #10
 def generate_fingerprints_and_create_list(self):
     #generate fingerprints of predicted ligands and known ligands:
     gen_mo = rdFingerprintGenerator.GetMorganGenerator(fpSize=2048,
                                                        radius=2)
     predicted_fps = [
         gen_mo.GetFingerprint(mol) for mol in self.predicted['molecules']
     ]
     true_fps = [
         gen_mo.GetFingerprint(mol) for mol in self.true_pos['molecules']
     ]
     similarities = list()
     for count, mol in enumerate(predicted_fps):
         tanimoto_values = ([
             DataStructs.TanimotoSimilarity(mol, i) for i in true_fps
         ])
         index_of_highest = np.argmax(tanimoto_values)
         similarities.append(tanimoto_values[index_of_highest])
     #module code is in: https://github.com/rdkit/rdkit/tree/master/Contrib/SA_Score
     sa_score = [
         sascorer.calculateScore(i)
         for i in list(self.predicted['molecules'])
     ]
     #create a list holding the QED drug-likeness score
     #reference: https://doi.org/10.1038/nchem.1243
     qeds = [qed(mol) for mol in self.predicted['molecules']]
     #create a list holding logp:
     logp = [Descriptors.MolLogP(m) for m in self.predicted['molecules']]
     #filter catalog usage instructions are here: https://github.com/rdkit/rdkit/pull/536
     params = FilterCatalogParams()
     params.AddCatalog(FilterCatalogParams.FilterCatalogs.BRENK)
     catalog = FilterCatalog(params)
     self.brenk = np.array(
         [catalog.HasMatch(m) for m in self.predicted['molecules']])
     #add these lists as columns to the 'predicted' pd.DataFrame
     self.predicted['similarities'] = similarities
     self.predicted['sa_score'] = sa_score
     self.predicted['qeds'] = qeds
     self.predicted['logp'] = logp
     print(self.predicted['logp'] < 6)
     shortlist_mask = ((self.predicted['similarities'] < 0.2) &
                       (self.predicted['sa_score'] < 4) &
                       (self.predicted['qeds'] > 0.25) &
                       (self.predicted['logp'] < 6) & (~self.brenk))
Example #11
    def __call__(self, mol):
        """
        Returns the QED of a SMILES string or a RdKit molecule.
        """

        # Error handling.
        if type(mol) == rdkit.Chem.rdchem.Mol:
            pass
        elif type(mol) == str:
            mol = Chem.MolFromSmiles(mol, sanitize=False)
            if mol is None:
                raise ValueError("Invalid SMILES string.")
        else:
            raise TypeError("Input must be from {str, rdkit.Chem.rdchem.Mol}")

        try:
            return qed(mol)
        # Catch atom valence exception raised by CalcCrippenDescriptor
        except Exception:
            return 0.
Example #12
    t_start = time.time()
    gen_x, gen_adj = generate(model, n_atom, n_atom_type, n_edge_type, device)

    gen_mols = gen_mol(gen_adj, gen_x, atomic_num_list, correct_validity=True)
    gen_results = metric_random_generation(gen_mols)

    t_end = time.time()

    gen_time.append(t_end - t_start)

    valid_ratio.append(gen_results['valid_ratio'])

    valid_mols = [mol for mol in gen_mols if check_chemical_validity(mol)]

    if args.property_name == 'qed':
        prop_scores = [qed(mol) for mol in valid_mols]
    elif args.property_name == 'plogp':
        prop_scores = [calculate_min_plogp(mol) for mol in valid_mols]

    inds = sorted(range(len(prop_scores)),
                  key=lambda k: prop_scores[k],
                  reverse=True)
    gen_smiles = [
        Chem.MolToSmiles(mol, isomericSmiles=False) for mol in gen_mols
    ]
    sorted_smiles = [gen_smiles[i] for i in inds]
    sorted_scores = [prop_scores[i] for i in inds]

    if args.save_result_file is not None:
        file = args.save_result_file
        with open(file, 'a+') as f:
Example #13
def get_reward_from_mol(mol):
  if mol not in REWARDS:
    REWARDS[mol] = qed(mol)
  return REWARDS[mol]
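RDKit Mol objects hash by object identity, so the cache above only helps when the very same Mol instance is scored repeatedly. A sketch that keys the cache by canonical SMILES instead (names here are illustrative):

from rdkit import Chem
from rdkit.Chem.QED import qed

REWARDS = {}

def get_reward_from_smiles(smiles):
  # Canonical SMILES gives a stable cache key across different Mol instances.
  key = Chem.CanonSmiles(smiles)
  if key not in REWARDS:
    REWARDS[key] = qed(Chem.MolFromSmiles(key))
  return REWARDS[key]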
Example #14
    # Predict with each model individually and sum predictions
    if args.dataset_type == 'multiclass':
        sum_preds = np.zeros(
            (len(test_data), args.num_tasks, args.multiclass_num_classes))
    else:
        sum_preds = np.zeros((len(test_data), args.num_tasks))

    model = load_checkpoint(checkpoint_path, cuda=args.cuda)
    model_preds = predict(model=model,
                          data=test_data,
                          batch_size=1,
                          scaler=scaler)
    sum_preds += np.array(model_preds)

    # Ensemble predictions
    return sum_preds[0][0]


#Debug

# print(predict_smile("../../dopamine_test/fold_0/model_2/model.pt","CC1=CC(=C1Cl)NC2=CC(=C2Cl)Cl"))
if __name__ == "__main__":
    mol = Chem.MolFromSmiles('CCCCOC1=CC1=C(NC(N)=O)C1=CC=C1O')
    sa = -1 * calculateScore(mol)
    print("SA score", (sa + 10) / (10 - 1))
    print("QED", qed(mol))
    print(
        "pki: ",
        predict_smile("../../model_hyperopt.pt",
                      "CCCCOC1=CC1=C(NC(N)=O)C1=CC=C1O"))
Example #15
 def get_qed(self):
     self.qed = qed(self._mol) 
     return self.qed
Example #16
def get_reward_from_mol(mol):
    """Returns the reward."""
    return qed(mol)
Example #17
 def add_qed_score(self):
     """create a list holding the QED drug-likeness score
     reference: https://doi.org/10.1038/nchem.1243"""
     qeds = [qed(mol) for mol in self.df['mols']]
     self.df['qed_score'] = qeds
     print(f'QED score range: {min(qeds)} - {max(qeds)}')
Example #18
def main(parser_namespace):
    # model loading
    disable_rdkit_logging()
    affinity_path = parser_namespace.affinity_path
    svae_path = parser_namespace.svae_path
    svae_weights_path = os.path.join(svae_path, "weights", "best_rec.pt")
    results_file_name = parser_namespace.optimisation_name

    logger.add(results_file_name + ".log", rotation="10 MB")

    svae_params = dict()
    with open(os.path.join(svae_path, "model_params.json"), "r") as f:
        svae_params.update(json.load(f))

    smiles_language = SMILESLanguage.load(
        os.path.join(svae_path, "selfies_language.pkl"))

    # initialize encoder, decoder, testVAE, and GP_generator_MW
    gru_encoder = StackGRUEncoder(svae_params)
    gru_decoder = StackGRUDecoder(svae_params)
    gru_vae = TeacherVAE(gru_encoder, gru_decoder)
    gru_vae.load_state_dict(
        torch.load(svae_weights_path, map_location=get_device()))

    gru_vae._associate_language(smiles_language)
    gru_vae.eval()

    smiles_generator = SmilesGenerator(gru_vae)

    with open(os.path.join(affinity_path, "model_params.json")) as f:
        predictor_params = json.load(f)
    affinity_predictor = MODEL_FACTORY["bimodal_mca"](predictor_params)
    affinity_predictor.load(
        os.path.join(
            affinity_path,
            f"weights/best_{predictor_params.get('p_metric', 'ROC-AUC')}_bimodal_mca.pt",
        ),
        map_location=get_device(),
    )
    affinity_protein_language = ProteinLanguage.load(
        os.path.join(affinity_path, "protein_language.pkl"))
    affinity_smiles_language = SMILESLanguage.load(
        os.path.join(affinity_path, "smiles_language.pkl"))
    affinity_predictor._associate_language(affinity_smiles_language)
    affinity_predictor._associate_language(affinity_protein_language)
    affinity_predictor.eval()

    erg_protein = "MASTIKEALSVVSEDQSLFECAYGTPHLAKTEMTASSSSDYGQTSKMSPRVPQQDWLSQPPARVTIKMECNPSQVNGSRNSPDECSVAKGGKMVGSPDTVGMNYGSYMEEKHMPPPNMTTNERRVIVPADPTLWSTDHVRQWLEWAVKEYGLPDVNILLFQNIDGKELCKMTKDDFQRLTPSYNADILLSHLHYLRETPLPHLTSDDVDKALQNSPRLMHARNTGGAAFIFPNTSVYPEATQRITTRPDLPYEPPRRSAWTGHGHPTPQSKAAQPSPSTVPKTEDQRPQLDPYQILGPTSSRLANPGSGQIQLWQFLLELLSDSSNSSCITWEGTNGEFKMTDPDEVARRWGERKSKPNMNYDKLSRALRYYYDKNIMTKVHGKRYAYKFDFHGIAQALQPHPPESSLYKYPSDLPYMGSYHAHPQKMNFVAPHPPALPVTSSSFFAAPNPYWNSPTGGIYPNTRLPTSHMPSHLGTYY"

    target_minimization_function = AffinityMinimization(
        smiles_generator, 30, affinity_predictor, erg_protein)
    qed_function = QEDMinimization(smiles_generator, 30)
    sa_function = SAMinimization(smiles_generator, 30)
    combined_minimization = CombinedMinimization(
        [target_minimization_function, qed_function, sa_function], 1,
        [0.75, 1, 0.5])
    target_optimizer = GPOptimizer(combined_minimization.evaluate)

    params = dict(
        dimensions=[(-5.0, 5.0)] * 256,
        acq_func="EI",
        n_calls=20,
        n_initial_points=19,
        initial_point_generator="random",
        random_state=1234,
    )
    logger.info("Optimisation parameters: {params}", params=params)

    # optimisation
    for j in range(5):
        res = target_optimizer.optimize(params)
        latent_point = torch.tensor([[res.x]])

        with open(results_file_name + "_LP" + str(j + 1) + ".pkl", "wb") as f:
            pickle.dump(latent_point, f, protocol=2)

        smile_set = set()

        while len(smile_set) < 20:
            smiles = smiles_generator.generate_smiles(
                latent_point.repeat(1, 30, 1))
            smile_set.update(set(smiles))
        smile_set = list(smile_set)

        pad_smiles_predictor = LeftPadding(
            affinity_predictor.smiles_padding_length,
            affinity_predictor.smiles_language.padding_index,
        )
        to_tensor = ToTensor(get_device())
        smiles_num = [
            torch.unsqueeze(
                to_tensor(
                    pad_smiles_predictor(
                        affinity_predictor.smiles_language.
                        smiles_to_token_indexes(smile))),
                0,
            ) for smile in smile_set
        ]

        smiles_tensor = torch.cat(smiles_num, dim=0)

        pad_protein_predictor = LeftPadding(
            affinity_predictor.protein_padding_length,
            affinity_predictor.protein_language.padding_index,
        )

        protein_num = torch.unsqueeze(
            to_tensor(
                pad_protein_predictor(
                    affinity_predictor.protein_language.
                    sequence_to_token_indexes([erg_protein]))),
            0,
        )
        protein_num = protein_num.repeat(len(smile_set), 1)

        with torch.no_grad():
            pred, _ = affinity_predictor(smiles_tensor, protein_num)
        affinities = torch.squeeze(pred, 1).numpy()

        sas = SAS()
        sa_scores = [sas(smile) for smile in smile_set]
        qed_scores = [qed(Chem.MolFromSmiles(smile)) for smile in smile_set]

        # save to file
        file = results_file_name + str(j + 1) + ".txt"
        logger.info("creating {file}", file=file)

        with open(file, "w") as f:
            f.write(
                f'{"point":<10}{"Affinity":<10}{"QED":<10}{"SA":<10}{"smiles":<15}\n'
            )
            for i in range(20):
                dat = [
                    i + 1, affinities[i], qed_scores[i], sa_scores[i],
                    smile_set[i]
                ]
                f.write(
                    f'{dat[0]:<10}{"%.3f"%dat[1]:<10}{"%.3f"%dat[2]:<10}{"%.3f"%dat[3]:<10}{dat[4]:<15}\n'
                )
Example #19
    def step(self, action):
        """
        Perform a given action
        :param action:
        :return: reward of 1 if resulting molecule graph does not exceed valency,
        -1 if otherwise
        """
        # init
        info = {}  # info we care about
        mol_old = copy.deepcopy(self.mol)  # keep old mol
        total_atoms = self.mol.GetNumAtoms()

        # take action
        if action[0, 3] == 0 or self.counter < self.min_action:  # not stop
            stop = False
            if action[0, 1] >= total_atoms:
                self._add_atom(action[0, 1] - total_atoms)  # add new node
                action[0, 1] = total_atoms  # new node id
                self._add_bond(action)  # add new edge
            else:
                self._add_bond(action)  # add new edge
        else:  # stop
            stop = True

        # calculate intermediate rewards
        if self.check_valency():
            if self.mol.GetNumAtoms() + self.mol.GetNumBonds(
            ) - mol_old.GetNumAtoms() - mol_old.GetNumBonds() > 0:
                reward_step = self.reward_step_total / self.max_atom  # successfully add node/edge
                self.smile_list.append(self.get_final_smiles())
            else:
                reward_step = -self.reward_step_total / self.max_atom  # edge exist
        else:
            reward_step = -self.reward_step_total / self.max_atom  # invalid action
            self.mol = mol_old

        # calculate terminal rewards
        # TODO: add terminal action

        if self.is_conditional:
            terminate_condition = (
                self.mol.GetNumAtoms() >= self.max_atom -
                self.possible_atom_types.shape[0] - self.min_action
                or self.counter >= self.max_action
                or stop) and self.counter >= self.min_action
        else:
            terminate_condition = (self.mol.GetNumAtoms() >= self.max_atom -
                                   self.possible_atom_types.shape[0]
                                   or self.counter >= self.max_action
                                   or stop) and self.counter >= self.min_action
        if terminate_condition or self.force_final:
            # default reward
            reward_valid = 2
            reward_qed = 0
            reward_sa = 0
            reward_logp = 0
            reward_final = 0
            flag_steric_strain_filter = True
            flag_zinc_molecule_filter = True

            if not self.check_chemical_validity():
                reward_valid -= 5
            else:
                # final mol object where any radical electrons are changed to bonds to hydrogen
                final_mol = self.get_final_mol()
                s = Chem.MolToSmiles(final_mol, isomericSmiles=True)
                final_mol = Chem.MolFromSmiles(s)

                # mol filters with negative rewards
                if not steric_strain_filter(
                        final_mol
                ):  # passes 3D conversion, no excessive strain
                    reward_valid -= 1
                    flag_steric_strain_filter = False
                if not zinc_molecule_filter(
                        final_mol
                ):  # does not contain any problematic functional groups
                    reward_valid -= 1
                    flag_zinc_molecule_filter = False

                # property rewards
                try:
                    # 1. QED reward. Can have values [0, 1]. Higher the better
                    reward_qed += qed(final_mol) * self.qed_ratio
                    # 2. Synthetic accessibility reward. Values naively normalized to [0, 1]. Higher the better
                    sa = -1 * calculateScore(final_mol)
                    reward_sa += (sa + 10) / (10 - 1) * self.sa_ratio
                    # 3. Logp reward. Higher the better
                    # reward_logp += MolLogP(self.mol)/10 * self.logp_ratio
                    reward_logp += reward_penalized_log_p(
                        final_mol) * self.logp_ratio
                    if self.reward_type == 'logppen':
                        reward_final += reward_penalized_log_p(final_mol) / 3
                    elif self.reward_type == 'logp_target':
                        # reward_final += reward_target(final_mol,target=self.reward_target,ratio=0.5,val_max=2,val_min=-2,func=MolLogP)
                        # reward_final += reward_target_logp(final_mol,target=self.reward_target)
                        reward_final += reward_target_new(
                            final_mol,
                            MolLogP,
                            x_start=self.reward_target,
                            x_mid=self.reward_target + 0.25)
                    elif self.reward_type == 'qed':
                        reward_final += reward_qed * 2
                    elif self.reward_type == 'qedsa':
                        reward_final += (reward_qed * 1.5 + reward_sa * 0.5)
                    elif self.reward_type == 'qed_target':
                        # reward_final += reward_target(final_mol,target=self.reward_target,ratio=0.1,val_max=2,val_min=-2,func=qed)
                        reward_final += reward_target_qed(
                            final_mol, target=self.reward_target)
                    elif self.reward_type == 'mw_target':
                        # reward_final += reward_target(final_mol,target=self.reward_target,ratio=40,val_max=2,val_min=-2,func=rdMolDescriptors.CalcExactMolWt)
                        # reward_final += reward_target_mw(final_mol,target=self.reward_target)
                        reward_final += reward_target_new(
                            final_mol,
                            rdMolDescriptors.CalcExactMolWt,
                            x_start=self.reward_target,
                            x_mid=self.reward_target + 25)

                    elif self.reward_type == 'gan':
                        reward_final = 0
                    else:
                        print('reward error!')
                        reward_final = 0

                except:  # if any property reward error, reset all
                    print('reward error')

            new = True  # end of episode
            if self.force_final:
                reward = reward_final
            else:
                reward = reward_step + reward_valid + reward_final
            info['smile'] = self.get_final_smiles()
            if self.is_conditional:
                info['reward_valid'] = self.conditional[
                    -1]  # FIXME temp change
            else:
                info['reward_valid'] = reward_valid
            info['reward_qed'] = reward_qed
            info['reward_sa'] = reward_sa
            info['final_stat'] = reward_final
            info['reward'] = reward
            info['flag_steric_strain_filter'] = flag_steric_strain_filter
            info['flag_zinc_molecule_filter'] = flag_zinc_molecule_filter
            info['stop'] = stop

        # use stepwise reward
        else:
            new = False
            # print('counter', self.counter, 'new', new, 'reward_step', reward_step)
            reward = reward_step

        # get observation
        ob = self.get_observation()

        self.counter += 1
        if new:
            self.counter = 0

        return ob, reward, new, info
Example #20
def generate(model, init_dataloader, n_atom, n_atom_type, n_edge_type, device,
             atomic_num_list):
    optim_dict = dict()
    save_smiles_dict = dict()
    parameters = model.parameters()

    ### Langevin dynamics
    for i, batch in enumerate(tqdm(init_dataloader)):
        gen_x = batch[0][0].to(device)
        gen_adj = batch[0][1].to(device)
        print(batch[0][0].shape)
        original_mols = turn_valid(gen_adj,
                                   gen_x,
                                   atomic_num_list,
                                   correct_validity=args.correct_validity)

        gen_x.requires_grad = True
        gen_adj.requires_grad = True

        requires_grad(parameters, False)
        model.eval()

        noise_x = torch.randn(gen_x.shape[0],
                              n_atom,
                              n_atom_type,
                              device=device)  # (10000, 9, 5)
        noise_adj = torch.randn(gen_adj.shape[0],
                                n_edge_type,
                                n_atom,
                                n_atom,
                                device=device)  #(10000, 4, 9, 9)

        for k in tqdm(range(args.sample_step)):

            noise_x.normal_(0, 0.005)
            noise_adj.normal_(0, 0.005)
            gen_x.data.add_(noise_x.data)
            gen_adj.data.add_(noise_adj.data)

            gen_out = model(gen_adj, gen_x)
            gen_out.sum().backward()
            gen_x.grad.data.clamp_(-0.1, 0.1)
            gen_adj.grad.data.clamp_(-0.1, 0.1)

            gen_x.data.add_(gen_x.grad.data, alpha=-args.step_size)
            gen_adj.data.add_(gen_adj.grad.data, alpha=-args.step_size)

            gen_x.grad.detach_()
            gen_x.grad.zero_()
            gen_adj.grad.detach_()
            gen_adj.grad.zero_()

            gen_x.data.clamp_(0, 1 + args.c)
            gen_adj.data.clamp_(0, 1)

            # if k % 2 == 0:
            gen_x_t = copy.deepcopy(gen_x)
            gen_adj_t = copy.deepcopy(gen_adj)
            gen_adj_t = gen_adj_t + gen_adj_t.permute(
                0, 1, 3, 2)  # A+A^T is a symmetric matrix
            gen_adj_t = gen_adj_t / 2
            gen_adj_t = gen_adj_t.softmax(
                dim=1)  ### Make all elements to be larger than 0
            max_bond = gen_adj_t.max(dim=1).values.reshape(
                args.batch_size, -1, n_atom, n_atom)  # (10000, 1, 9, 9)
            gen_adj_t = torch.floor(
                gen_adj_t / max_bond
            )  # (10000, 4, 9, 9) /  (10000, 1, 9, 9) -->  (10000, 4, 9, 9)
            val_res = turn_valid(gen_adj_t,
                                 gen_x_t,
                                 atomic_num_list,
                                 correct_validity=args.correct_validity)
            assert len(val_res['valid_mols']) == len(
                original_mols['valid_mols'])

            for mol_idx in range(len(val_res['valid_mols'])):
                if val_res['valid_mols'][mol_idx] is not None:
                    tmp_mol = val_res['valid_mols'][mol_idx]
                    tmp_smiles = val_res['valid_smiles'][mol_idx]
                    o_mol = original_mols['valid_mols'][mol_idx]
                    o_smiles = original_mols['valid_smiles'][mol_idx]
                    # calculate property improvement over the original molecule
                    try:
                        if args.property_name == 'qed':
                            imp_p = qed(tmp_mol) - qed(o_mol)
                        elif args.property_name == 'plogp':
                            imp_p = calculate_min_plogp(
                                tmp_mol) - calculate_min_plogp(o_mol)
                        else:
                            continue
                        # calculate sim
                        current_sim = reward_target_molecule_similarity(
                            tmp_mol, o_mol)
                        update_optim_dict(optim_dict, o_smiles, tmp_smiles,
                                          imp_p, current_sim)
                        update_save_dict(save_smiles_dict, o_smiles,
                                         tmp_smiles, imp_p, current_sim)
                    except Exception:
                        # property calculation failed (e.g. plogp error); skip
                        pass

    return optim_dict, save_smiles_dict
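The inner loop above is a Langevin-style update: inject Gaussian noise, step against the gradient of the model output, and clamp the result. A minimal self-contained sketch of that update rule on a generic energy function; all names and values here are illustrative, not from the original code.

import torch

def langevin_step(x, energy_fn, step_size=0.2, noise_std=0.005, clamp=(0.0, 1.0)):
    # One update: add Gaussian noise, take a gradient step downhill, clamp.
    x = x.detach().requires_grad_(True)
    x.data.add_(noise_std * torch.randn_like(x))
    energy_fn(x).sum().backward()
    x.grad.data.clamp_(-0.1, 0.1)
    x.data.add_(x.grad.data, alpha=-step_size)
    x.data.clamp_(*clamp)
    return x.detach()

x = torch.rand(8, 9, 5)                                    # e.g. a batch of atom-type matrices
x = langevin_step(x, lambda t: (t ** 2).sum(dim=(1, 2)))   # toy quadratic energy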