예제 #1
0
파일: test.py 프로젝트: zhenghl2/chemprop
 def test_predict_compound_names(self):
     """Check that prediction runs when the test CSV is read with compound names.

     Points ``self.args.test_path`` at ``delaney_toy_smiles_names.csv`` located
     next to this test file, enables ``use_compound_names``, and then runs the
     argument post-processing followed by the prediction itself. Any error
     raised along the way is reported as a test failure.
     """
     try:
         test_dir = os.path.dirname(os.path.abspath(__file__))
         self.args.test_path = os.path.join(test_dir, 'delaney_toy_smiles_names.csv')
         self.args.use_compound_names = True
         modify_predict_args(self.args)
         make_predictions(self.args)
     except Exception:
         # Deliberately not a bare ``except:`` so that KeyboardInterrupt and
         # SystemExit still propagate instead of being reported as a failure.
         self.fail('predict_compound_names')
def predict_desc(args):
    """Predict per-atom and per-bond descriptors for all reactants in a reaction set.

    The reactions are read from ``args.data_path`` (CSV, first column used as
    the index, reaction SMILES in the ``rxn_smiles`` column). The unique
    reactant SMILES are passed to ``make_predictions`` using the ``QM_137k.pt``
    checkpoint found under ``trained_model`` next to the installed ``chemprop``
    package.

    Side effects:
        * ``args.test_path``, ``args.preds_path`` and ``args.checkpoint_path``
          are overwritten.
        * ``reactants_descriptors.pickle`` (raw) and
          ``reactants_descriptors_norm.pickle`` (min-max normalized) are written
          to ``args.output_dir``, which is created if missing.
        * When ``args.predict`` is falsy the fitted scalers are written to
          ``scalers.pickle`` in ``args.model_dir``; otherwise the scalers are
          loaded from that file and reused.

    Args:
        args: Namespace providing ``data_path``, ``output_dir``, ``model_dir``
            and ``predict`` in addition to whatever ``modify_predict_args`` and
            ``make_predictions`` require.

    Returns:
        The normalized descriptor ``DataFrame`` with one row per reactant.
    """
    import chemprop
    chemprop_root = os.path.dirname(os.path.dirname(chemprop.__file__))

    # Trick chemprop: the SMILES are handed over in memory below, so these two
    # paths are only placeholders (presumably needed to satisfy
    # modify_predict_args -- TODO confirm).
    args.test_path = 'foo'
    args.preds_path = 'foo'
    args.checkpoint_path = os.path.join(chemprop_root, 'trained_model',
                                        'QM_137k.pt')
    modify_predict_args(args)

    def num_atoms_bonds(smiles):
        """Return (atom count, bond count) of the molecule with explicit hydrogens."""
        mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
        return len(mol.GetAtoms()), len(mol.GetBonds())

    def split_per_molecule(values, offsets):
        """Cut the flattened prediction array into one chunk per molecule.

        ``offsets`` are cumulative element counts; np.split yields one extra
        trailing chunk after the last offset, which is dropped.
        """
        return np.split(values.flatten(), offsets)[:-1]

    # predict descriptors for reactants in the reactions
    reactivity_data = pd.read_csv(args.data_path, index_col=0)
    reactants = set()
    for rxn_smiles in reactivity_data['rxn_smiles']:
        # 'reactants>agents>products': exactly three fields are expected.
        reactant_side, _, _ = rxn_smiles.split('>')
        reactants.update(reactant_side.split('.'))
    reactants = list(reactants)

    print('Predicting descriptors for reactants...')
    test_preds, test_smiles = make_predictions(args, smiles=reactants)

    # Atom-level outputs occupy indices 0-3, bond-level outputs indices 4-5.
    partial_charge, partial_neu, partial_elec, NMR = (test_preds[i]
                                                      for i in range(4))
    bond_order, bond_distance = test_preds[4], test_preds[5]

    n_atoms, n_bonds = zip(*[num_atoms_bonds(x) for x in reactants])

    # Compute the split offsets once instead of once per descriptor.
    atom_offsets = np.cumsum(np.array(n_atoms))
    bond_offsets = np.cumsum(np.array(n_bonds))

    df = pd.DataFrame({
        'smiles': reactants,
        'partial_charge': split_per_molecule(partial_charge, atom_offsets),
        'fukui_neu': split_per_molecule(partial_neu, atom_offsets),
        'fukui_elec': split_per_molecule(partial_elec, atom_offsets),
        'NMR': split_per_molecule(NMR, atom_offsets),
        'bond_order': split_per_molecule(bond_order, bond_offsets),
        'bond_length': split_per_molecule(bond_distance, bond_offsets)
    })

    invalid = check_chemprop_out(df)
    # FIXME remove invalid molecules from reaction dataset
    print(invalid)

    # makedirs(exist_ok=True) avoids the check-then-create race and also
    # handles nested output directories.
    os.makedirs(args.output_dir, exist_ok=True)

    df.to_pickle(os.path.join(args.output_dir, 'reactants_descriptors.pickle'))
    scalers_path = os.path.join(args.model_dir, 'scalers.pickle')

    if not args.predict:
        # Training run: fit the scalers on this data and persist them.
        df, scalers = min_max_normalize(df)
        with open(scalers_path, 'wb') as scalers_file:
            pickle.dump(scalers, scalers_file)
    else:
        # Prediction run: reuse the scalers stored by a previous training run.
        # NOTE(review): pickle.load executes arbitrary code from the file; only
        # point model_dir at trusted locations.
        with open(scalers_path, 'rb') as scalers_file:
            scalers = pickle.load(scalers_file)
        df, _ = min_max_normalize(df, scalers)

    df.to_pickle(
        os.path.join(args.output_dir, 'reactants_descriptors_norm.pickle'))

    return df
예제 #3
0
 def test_predict(self):
     """Check that prediction runs end-to-end with the fixture's default arguments.

     Runs the argument post-processing and then the prediction on
     ``self.args``; any error raised along the way is reported as a test
     failure.
     """
     try:
         modify_predict_args(self.args)
         make_predictions(self.args)
     except Exception:
         # Deliberately not a bare ``except:`` so that KeyboardInterrupt and
         # SystemExit still propagate instead of being reported as a failure.
         self.fail('predict')
예제 #4
0
def predict():
    """Renders the predict page and makes predictions if the method is POST.

    On GET the empty predict page is rendered. On POST the SMILES are taken, in
    order of precedence, from the ``textSmiles`` form field (whitespace
    separated), the ``drawSmiles`` form field, or the uploaded ``data`` file.
    Predictions are made with every model registered for the checkpoint id in
    the ``checkpointName`` form field; the optional ``gpu`` form field selects
    a device, with the literal string ``'None'`` disabling CUDA.

    Returns:
        The rendered predict page, carrying either the predictions (with a
        warning when some SMILES were invalid) or an error message when no
        SMILES were given or all of them were invalid.
    """
    if request.method == 'GET':
        return render_predict()

    # Get arguments
    ckpt_id = request.form['checkpointName']

    if request.form['textSmiles'] != '':
        smiles = request.form['textSmiles'].split()
    elif request.form['drawSmiles'] != '':
        smiles = [request.form['drawSmiles']]
    else:
        # Upload data file with SMILES
        data = request.files['data']
        data_name = secure_filename(data.filename)
        data_path = os.path.join(app.config['TEMP_FOLDER'], data_name)
        data.save(data_path)

        # Check if header is smiles: keep the first header cell when it parses
        # as a molecule, so it is not lost as a presumed header.
        possible_smiles = get_header(data_path)[0]
        smiles = [possible_smiles] if Chem.MolFromSmiles(possible_smiles) is not None else []

        # Get remaining smiles
        smiles.extend(get_smiles(data_path))

    models = db.get_models(ckpt_id)
    model_paths = [os.path.join(app.config['CHECKPOINT_FOLDER'], f'{model["id"]}.pt') for model in models]

    task_names = load_task_names(model_paths[0])
    num_tasks = len(task_names)
    gpu = request.form.get('gpu')

    # Create and modify args
    args = load_args(model_paths[0])

    # A model trained with a features file gets its features regenerated at
    # prediction time instead of being read from that path.
    if args.features_path is not None:
        args.features_generator = ["rdkit_2d_normalized"]
        args.features_path = None

    preds_path = os.path.join(app.config['TEMP_FOLDER'], app.config['PREDICTIONS_FILENAME'])
    args.test_path = 'None'  # TODO: Remove this hack to avoid assert crashing in modify_predict_args
    args.preds_path = preds_path
    args.checkpoint_paths = model_paths
    if gpu is not None:
        if gpu == 'None':
            args.no_cuda = True
        else:
            args.gpu = int(gpu)

    modify_predict_args(args)

    # Run predictions
    preds = make_predictions(args, smiles=smiles)

    # Handle "nothing submitted" before the all-invalid check: all() over an
    # empty sequence is True and would otherwise report the wrong error.
    if len(preds) == 0:
        return render_predict(errors=["No SMILES strings given"])

    if all(p is None for p in preds):
        return render_predict(errors=['All SMILES are invalid'])

    # Record invalid entries before they are overwritten below; checking for
    # None afterwards could never succeed.
    has_invalid = any(p is None for p in preds)

    # Replace invalid smiles with message
    invalid_smiles_warning = "Invalid SMILES String"
    preds = [pred if pred is not None else [invalid_smiles_warning] * num_tasks for pred in preds]

    return render_predict(predicted=True,
                          smiles=smiles,
                          num_smiles=min(10, len(smiles)),
                          show_more=max(0, len(smiles)-10),
                          task_names=task_names,
                          num_tasks=num_tasks,
                          preds=preds,
                          warnings=["List contains invalid SMILES strings"] if has_invalid else None,
                          errors=None)
예제 #5
0
def predict():
    """Render the predict page and, on POST, run predictions for the given SMILES.

    On GET the empty predict page is rendered. On POST the SMILES are taken, in
    order of precedence, from the uploaded ``data`` file, the ``textSmiles``
    form field (whitespace separated), or the ``drawSmiles`` form field.
    Predictions are made with the checkpoint file named by the
    ``checkpointName`` form field; the optional ``gpu`` form field selects a
    device, with the literal string ``'None'`` disabling CUDA.

    Returns:
        The rendered predict page, carrying either the predictions (with a
        warning when some SMILES were invalid) or an error message when no
        SMILES were given or all of them were invalid.
    """
    if request.method == 'GET':
        return render_predict()

    # Get arguments
    checkpoint_name = request.form['checkpointName']

    if 'data' in request.files:
        # Upload data file with SMILES
        data = request.files['data']
        data_name = secure_filename(data.filename)
        data_path = os.path.join(app.config['TEMP_FOLDER'], data_name)
        data.save(data_path)

        # Check if header is smiles: keep the first header cell when it parses
        # as a molecule, so it is not lost as a presumed header.
        possible_smiles = get_header(data_path)[0]
        if Chem.MolFromSmiles(possible_smiles) is not None:
            smiles = [possible_smiles]
        else:
            smiles = []

        # Get remaining smiles
        smiles.extend(get_smiles(data_path))
    elif request.form['textSmiles'] != '':
        smiles = request.form['textSmiles'].split()
    else:
        smiles = [request.form['drawSmiles']]

    checkpoint_path = os.path.join(app.config['CHECKPOINT_FOLDER'],
                                   checkpoint_name)
    task_names = load_task_names(checkpoint_path)
    num_tasks = len(task_names)
    gpu = request.form.get('gpu')

    # Create and modify args
    parser = ArgumentParser()
    add_predict_args(parser)
    args = parser.parse_args()

    preds_path = os.path.join(app.config['TEMP_FOLDER'],
                              app.config['PREDICTIONS_FILENAME'])
    args.test_path = 'None'  # TODO: Remove this hack to avoid assert crashing in modify_predict_args
    args.preds_path = preds_path
    args.checkpoint_path = checkpoint_path
    args.write_smiles = True
    if gpu is not None:
        if gpu == 'None':
            args.no_cuda = True
        else:
            args.gpu = int(gpu)

    modify_predict_args(args)

    # Run predictions
    preds = make_predictions(args, smiles=smiles)

    # Handle "nothing submitted" before the all-invalid check: all() over an
    # empty sequence is True and would otherwise report the wrong error.
    if len(preds) == 0:
        return render_predict(errors=["No SMILES strings given"])

    if all(p is None for p in preds):
        return render_predict(errors=['All SMILES are invalid'])

    # Record invalid entries before they are overwritten below; checking for
    # None afterwards could never succeed.
    has_invalid = any(p is None for p in preds)

    # Replace invalid smiles with message
    invalid_smiles_warning = "Invalid SMILES String"
    preds = [
        pred if pred is not None else [invalid_smiles_warning] * num_tasks
        for pred in preds
    ]

    return render_predict(
        predicted=True,
        smiles=smiles,
        num_smiles=min(10, len(smiles)),
        show_more=max(0,
                      len(smiles) - 10),
        task_names=task_names,
        num_tasks=num_tasks,
        preds=preds,
        warnings=["List contains invalid SMILES strings"]
        if has_invalid else None,
        errors=None)