示例#1
0
    def setUp(self):
        """Train a small model and prepare the predict arguments used by the tests.

        Runs ``cross_validate`` on ``delaney_toy.csv`` (located next to this
        file) with minimal settings (batch size 2, hidden size 5, 1 epoch),
        saving into a fresh temporary directory kept on ``self.temp_dir``.
        Then builds predict arguments that use that directory as
        ``checkpoint_dir`` and ``delaney_toy_smiles.csv`` as ``test_path``,
        and stores them on ``self.args``.
        """
        test_dir = os.path.dirname(os.path.abspath(__file__))

        # Training arguments: start from the parser defaults, then override.
        parser = ArgumentParser()
        add_train_args(parser)
        args = parser.parse_args([])
        args.data_path = os.path.join(test_dir, 'delaney_toy.csv')
        args.dataset_type = 'regression'
        args.batch_size = 2
        args.hidden_size = 5
        args.epochs = 1
        args.quiet = True
        self.temp_dir = TemporaryDirectory()
        args.save_dir = self.temp_dir.name
        logger = create_logger(name='train',
                               save_dir=args.save_dir,
                               quiet=args.quiet)
        modify_train_args(args)
        cross_validate(args, logger)
        clear_cache()

        # Prediction arguments, pointing at the checkpoints saved above.
        parser = ArgumentParser()
        add_predict_args(parser)
        args = parser.parse_args([])
        args.batch_size = 2
        args.checkpoint_dir = self.temp_dir.name
        # Keep the predictions file inside the test's own temporary directory.
        # Taking `.name` from a NamedTemporaryFile that is dropped right away
        # leaves an unreserved path (the file is deleted when the object is
        # closed) and anything written there later is never cleaned up.
        args.preds_path = os.path.join(self.temp_dir.name, 'preds.csv')
        args.test_path = os.path.join(test_dir, 'delaney_toy_smiles.csv')
        self.args = args
def predict_properties(pred_smiles: List[Optional[str]],
                       checkpoint_dir: Optional[str],
                       computed_prop: Optional[str]) -> List[float]:
    """Score SMILES strings with model checkpoints or with a computed property.

    Exactly one of ``checkpoint_dir`` and ``computed_prop`` must be provided;
    the other must be ``None``.

    :param pred_smiles: SMILES strings to score. ``None`` entries are dropped
        before scoring, so the result can be shorter than this list.
    :param checkpoint_dir: Value assigned to the predict args' ``checkpoint_dir``
        before calling ``make_predictions``, or ``None``.
    :param computed_prop: Name of the computed property to use, one of
        ``'penalized_logp'``, ``'logp'``, ``'qed'``, ``'sascore'``, ``'drd2'``,
        or ``None``.
    :return: With ``checkpoint_dir``, the first element of every prediction
        returned by ``make_predictions``; with ``computed_prop``, the scorer's
        value for each non-``None`` SMILES, in input order.
    :raises ValueError: If both or neither of ``checkpoint_dir`` and
        ``computed_prop`` are given, or if ``computed_prop`` is not supported.
    """
    # Check that exactly one of checkpoint_dir and computed_prop is provided.
    # Raised explicitly (not asserted) so the check survives `python -O`.
    if (checkpoint_dir is None) == (computed_prop is None):
        raise ValueError('Exactly one of checkpoint_dir and computed_prop must be provided')

    pred_smiles = [smiles for smiles in pred_smiles if smiles is not None]

    if checkpoint_dir is not None:
        # Create and modify predict args (only needed on the checkpoint path)
        parser = ArgumentParser()
        add_predict_args(parser)
        args = parser.parse_args([])

        # Placeholder value; the SMILES are handed to make_predictions directly.
        args.test_path = 'None'
        args.checkpoint_dir = checkpoint_dir
        update_checkpoint_args(args)
        args.quiet = True

        print('Make predictions')
        with NamedTemporaryFile() as temp_file:
            args.preds_path = temp_file.name
            property_predictions = make_predictions(args, smiles=pred_smiles)

        # Keep only the first value of each prediction.
        return [property_prediction[0] for property_prediction in property_predictions]

    # Scorer names are resolved lazily, only for the requested property.
    if computed_prop == 'penalized_logp':
        scorer = penalized_logp
    elif computed_prop == 'logp':
        scorer = logp
    elif computed_prop == 'qed':
        scorer = qed
    elif computed_prop == 'sascore':
        scorer = sascore
    elif computed_prop == 'drd2':
        scorer = drd2
    else:
        raise ValueError(f'Computed property "{computed_prop}" not supported')

    return [scorer(s) for s in pred_smiles]
示例#3
0
def predict():
    """Render the predict page and, on POST, run predictions for the submitted SMILES.

    On GET the empty form is rendered. On POST the SMILES come either from an
    uploaded file (first comma-separated field of each line) or from the
    whitespace-separated ``smiles`` form field; predictions are made with the
    checkpoint named by the ``checkpointName`` form field and rendered in
    ``predict.html``.
    """
    if request.method == 'GET':
        return render_template('predict.html',
                               checkpoints=get_checkpoints(),
                               cuda=app.config['CUDA'],
                               gpus=app.config['GPUS'])

    # Get arguments
    checkpoint_name = request.form['checkpointName']

    if 'data' in request.files:
        # Upload data file with SMILES
        show_file_upload = True
        data = request.files['data']
        data_name = secure_filename(data.filename)
        data_path = os.path.join(app.config['TEMP_FOLDER'], data_name)
        data.save(data_path)

        smiles = []
        with open(data_path, 'r') as f:
            # The first line is either a header or already a SMILES string.
            # MolFromSmiles reports a parse failure by returning None instead
            # of raising, so the result has to be tested explicitly; otherwise
            # a header line is always treated as a SMILES.
            first_entry = f.readline().strip().split(',')[0]
            if first_entry and Chem.MolFromSmiles(first_entry) is not None:
                smiles.append(first_entry)
            smiles.extend(line.strip().split(',')[0] for line in f)
    else:
        show_file_upload = False
        smiles = request.form['smiles'].split()

    checkpoint_path = os.path.join(app.config['CHECKPOINT_FOLDER'],
                                   checkpoint_name)
    task_names = load_task_names(checkpoint_path)
    gpu = request.form.get('gpu', None)

    # Create and modify args. Parse an empty list so the defaults are used;
    # a bare parse_args() would read the web server's own sys.argv.
    parser = ArgumentParser()
    add_predict_args(parser)
    args = parser.parse_args([])

    preds_path = os.path.join(app.config['TEMP_FOLDER'],
                              app.config['PREDICTIONS_FILENAME'])
    args.preds_path = preds_path
    args.checkpoint_paths = [checkpoint_path]
    if gpu is not None:
        # The form sends the literal string 'None' to request CPU-only.
        if gpu == 'None':
            args.no_cuda = True
        else:
            args.gpu = int(gpu)

    invalid_smiles_warning = "Invalid SMILES String"
    if smiles:
        # Run prediction
        preds = make_predictions(args,
                                 smiles=smiles,
                                 invalid_smiles_warning=invalid_smiles_warning)
    else:
        preds = []

    # NOTE(review): the membership test below only matches if make_predictions
    # places the bare warning string in the top-level list -- confirm against
    # its return shape.
    return render_template(
        'predict.html',
        checkpoints=get_checkpoints(),
        cuda=app.config['CUDA'],
        gpus=app.config['GPUS'],
        predicted=True,
        smiles=smiles,
        num_smiles=min(10, len(smiles)),
        show_more=max(0,
                      len(smiles) - 10),
        task_names=task_names,
        num_tasks=len(task_names),
        preds=preds,
        show_file_upload=show_file_upload,
        warning="List contains invalid SMILES strings"
        if invalid_smiles_warning in preds else None,
        error="No SMILES strings given" if len(preds) == 0 else None)
示例#4
0
def predict():
    """Renders the predict page and makes predictions if the method is POST.

    SMILES are taken, in order of precedence, from the ``textSmiles`` form
    field (whitespace-separated), the ``drawSmiles`` form field (a single
    string), or an uploaded ``data`` file. Predictions are made with all
    models that ``db.get_models`` returns for the ``checkpointName`` form
    field. Entries for which ``make_predictions`` returns ``None`` are shown
    as "Invalid SMILES String" and trigger a warning on the page.
    """
    if request.method == 'GET':
        return render_predict()

    # Get arguments
    ckpt_id = request.form['checkpointName']

    if request.form['textSmiles'] != '':
        smiles = request.form['textSmiles'].split()
    elif request.form['drawSmiles'] != '':
        smiles = [request.form['drawSmiles']]
    else:
        # Upload data file with SMILES
        data = request.files['data']
        data_name = secure_filename(data.filename)
        data_path = os.path.join(app.config['TEMP_FOLDER'], data_name)
        data.save(data_path)

        # Check if header is smiles
        possible_smiles = get_header(data_path)[0]
        smiles = [possible_smiles] if str_to_mol(possible_smiles) is not None else []

        # Get remaining smiles
        smiles.extend(get_smiles(data_path))

    # Report missing input up front. Checking for empty predictions after the
    # "all invalid" test below can never fire, because all() over an empty
    # list is True.
    if not smiles:
        return render_predict(errors=["No SMILES strings given"])

    models = db.get_models(ckpt_id)
    model_paths = [os.path.join(app.config['CHECKPOINT_FOLDER'], f'{model["id"]}.pt') for model in models]

    task_names = load_task_names(model_paths[0])
    num_tasks = len(task_names)
    gpu = request.form.get('gpu')

    # Create and modify args
    parser = ArgumentParser()
    add_predict_args(parser)
    args = parser.parse_args([])

    preds_path = os.path.join(app.config['TEMP_FOLDER'], app.config['PREDICTIONS_FILENAME'])
    args.test_path = 'None'  # TODO: Remove this hack to avoid assert crashing in modify_predict_args
    args.preds_path = preds_path
    args.checkpoint_paths = model_paths
    if gpu is not None:
        # The form sends the literal string 'None' to request CPU-only.
        if gpu == 'None':
            args.no_cuda = True
        else:
            args.gpu = int(gpu)

    modify_predict_args(args)

    # Run predictions
    preds = make_predictions(args, smiles=smiles)

    if all(p is None for p in preds):
        return render_predict(errors=['All SMILES are invalid'])

    # Record whether anything was invalid BEFORE the None entries are replaced;
    # testing `None in preds` afterwards can never be true.
    has_invalid_smiles = any(pred is None for pred in preds)

    # Replace invalid smiles with message
    invalid_smiles_warning = "Invalid SMILES String"
    preds = [pred if pred is not None else [invalid_smiles_warning] * num_tasks for pred in preds]

    return render_predict(predicted=True,
                          smiles=smiles,
                          num_smiles=min(10, len(smiles)),
                          show_more=max(0, len(smiles) - 10),
                          task_names=task_names,
                          num_tasks=num_tasks,
                          preds=preds,
                          warnings=["List contains invalid SMILES strings"] if has_invalid_smiles else None,
                          errors=None)
import argparse
from ml_QM_GNN.WLN.data_loading import Graph_DataLoader
from ml_QM_GNN.graph_utils.mol_graph import initialize_qm_descriptors
from scipy.special import softmax
from rdkit import rdBase
from tqdm import tqdm

from rdkit import Chem

#find chemprop root path
from chemprop.parsing import add_predict_args

# Turn off RDKit's 'rdApp.warning' log channel.
rdBase.DisableLog('rdApp.warning')

# Command-line interface: chemprop's predict arguments plus this script's own
# options.
parser = argparse.ArgumentParser()
add_predict_args(parser)

# Boolean switches.
parser.add_argument(
    '-r', '--restart',
    help='restart the training using the saved the checkpoint file',
    action='store_true',
)
parser.add_argument(
    '-p', '--predict',
    help='predict reactivity for a given .csv file',
    action='store_true',
)

# Model variant selection.
parser.add_argument(
    '-m', '--model',
    help='model can be used',
    choices=['ml_QM_GNN', 'QM_GNN', 'GNN'],
    default='ml_QM_GNN',
)
parser.add_argument('--model_dir',