def extract_st(cms_file, trj_dir, asl='all', frames=1):
    """
    Extract frames from a trajectory as structure.Structure objects.

    :param cms_file: Desmond cms file, passed to ``topo.read_cms``
    :param trj_dir: trajectory location, passed to ``traj.read_traj``
    :param asl: ASL expression selecting the atoms kept in every extracted frame
    :param frames: if int: number of frames uniformly distributed over the
        trajectory (first and last frame included); if iterable: list of
        frame indices
    :return structures: list of schrodinger.structure.Structure, each titled
        ``'Frame <index>'``
    :type structures: list
    :raises RuntimeError: if extracting one of the explicitly listed frames
        fails (e.g. an out-of-range index); the original error is chained
    """
    structures = []
    msys_model, cms_model = topo.read_cms(cms_file)
    frame_arch = traj.read_traj(trj_dir)

    def _extract_frame(f):
        st = topo.update_cms(cms_model, frame_arch[f])
        # Both branches select atoms the same way: ASL -> atom indices -> extract
        st = st.extract(evaluate_asl(st, asl))
        st.title = 'Frame {}'.format(f)
        return st

    if isinstance(frames, (int, np.integer)):
        # Evenly spaced indices from the first to the last frame
        for f in np.linspace(0, len(frame_arch) - 1, frames).astype(int):
            structures.append(_extract_frame(f))
    else:
        try:
            for f in frames:
                structures.append(_extract_frame(f))
        except Exception as e:
            raise RuntimeError(e) from e

    return structures
Beispiel #2
0
    def __init__(self, cms_file, trj_dir, queue, frames=None, params=None):
        """
        Set up the worker process for one simulation.

        :param cms_file: Desmond cms file (converted with ``str`` before reading)
        :param trj_dir: trajectory location (converted with ``str`` before reading)
        :param queue: queue stored as ``self.queue``; presumably used by the
            worker to send results back to the parent — verify in ``run``
        :param frames: optional collection of frame indices to keep (any
            iterable of ints, including generators); ``None`` keeps all frames
        :param params: optional iterable of keyword dicts, each forwarded to
            ``self.add_param``
        """
        multiprocessing.Process.__init__(self)

        self.queue = queue

        # load cms_model
        self.msys_model, self.cms_model = topo.read_cms(str(cms_file))

        # Get framelist
        trajectory = traj.read_traj(str(trj_dir))
        if frames is not None:
            # Materialize the indices once: O(1) membership tests, and a
            # one-shot iterator is not exhausted after the first frame
            wanted = set(frames)
            self.frame_list = [
                frame for i, frame in enumerate(trajectory) if i in wanted
            ]
        else:
            self.frame_list = trajectory

        self.total_frames = len(self.frame_list)

        self.align = Superimposer()

        # Calculation parameters
        self.calculation_parameters = []

        if params is not None:
            for param in params:
                self.add_param(**param)
Beispiel #3
0
def get_desmond_cms(api, stid, dir=None):
    """
    Download the Desmond cms file of a structure and load it.

    :param api: client exposing ``get_structure_file``
    :param stid: structure id
    :param dir: directory in which ``desmond.cms`` is written. If None, a
        temporary directory is created and removed at interpreter exit.
    :return: result of ``topo.read_cms`` — unpacked as
        ``(msys_model, cms_model)`` by the callers in this file
    :raises HTTPError: if the request fails or returns an error status
    """
    if dir is None:
        tempdir = tempfile.TemporaryDirectory()
        atexit.register(tempdir.cleanup)
        dir = tempdir.name

    filename = os.path.join(dir, 'desmond.cms')
    resp = None
    try:
        resp = api.get_structure_file(stid, file_type='desmond_cms')
        resp.raise_for_status()
    except Exception as e:
        # resp is still None when the request itself raised
        if resp is not None:
            resp.close()
        raise HTTPError(e) from e
    try:
        with open(filename, 'wb') as fh:
            fh.write(resp.content)
    finally:
        resp.close()
    return topo.read_cms(filename)
Beispiel #4
0
def preprocess_hbond(api, dataset, data_type=None, no_ligand=False):
    """
    Build a table of hydrogen-bond frequencies with one row per structure.

    :param api: client used to download the input files; if None, file paths
        are read from the ``desmond_cms`` and ``trj_hbonds`` columns of *dataset*
    :param dataset: DataFrame with a ``structure_id`` column and optionally
        ``ligand_resnum``/``ligand_chain``. When *api* is None, rows missing
        ``desmond_cms`` or ``trj_hbonds`` are dropped in place, i.e. the
        caller's DataFrame is modified.
    :param data_type: Placeholder
    :param no_ligand: do not rename ligand residues
    :return: DataFrame indexed by structure id with one column per h-bond
        label produced by ``_load_hbonds``; missing values are 0
    :raises ValueError: if the ligand residue cannot be determined: only one
        of ``ligand_resnum``/``ligand_chain`` is present, the structure does
        not contain exactly one ligand, or the ligand spans several residues
    """
    import tempfile

    if api is None:
        dataset.dropna(axis=0,
                       subset=['desmond_cms', 'trj_hbonds'],
                       inplace=True)

    has_resnum = 'ligand_resnum' in dataset.columns
    has_chain = 'ligand_chain' in dataset.columns

    # Seed the index with every structure so that structures without any
    # h-bond still get a (zero-filled) row
    hbonds_df = pd.DataFrame(index=dataset.structure_id)

    for index, stid in zip(dataset.index, dataset.loc[:, 'structure_id']):
        # Get data
        if api is None:
            msys_model, cms_model = topo.read_cms(dataset.loc[index,
                                                              'desmond_cms'])
            data = pd.read_csv(dataset.loc[index, 'trj_hbonds'],
                               sep=',',
                               index_col=0)
            # Replace the stored index with a plain 0..n-1 range
            data.index = np.arange(data.shape[0])
        else:
            # Scoped temp dir: removed right after loading instead of at exit
            with tempfile.TemporaryDirectory() as tempdir:
                msys_model, cms_model = get_desmond_cms(api, stid, dir=tempdir)
            data = get_trj_hbonds(api, stid)

        # Get ligand
        if no_ligand:
            ligand_resid = None
        elif has_resnum and has_chain:
            ligand_resid = (int(dataset.loc[index, 'ligand_resnum']),
                            dataset.loc[index, 'ligand_chain'])
        elif has_resnum or has_chain:
            # Reject partial ligand columns instead of reusing a stale value
            raise ValueError(
                'Structure {}\nligand_resnum and ligand_chain must be given '
                'together'.format(stid))
        else:
            ligands = find_ligands(st=cms_model)
            if len(ligands) != 1:
                raise ValueError(
                    'Structure {}\nExpected exactly one ligand, found {}'.format(
                        stid, len(ligands)))
            resids = {(atm.resnum, atm.chain.strip())
                      for atm in ligands[0].st.atom}
            if len(resids) > 1:
                raise ValueError(
                    'Structure {}\nFound multiple ligand residues!'.format(stid))
            ligand_resid = next(iter(resids), None)

        frequency_df = _load_hbonds(cms_model, data, ligand_resid=ligand_resid)
        for fi in frequency_df.index:
            hbonds_df.loc[stid, fi] = frequency_df.loc[fi, 'frequency']
    hbonds_df.fillna(0, inplace=True)
    return hbonds_df
Beispiel #5
0
def preprocess_rms(api, dataset, data_type=None, no_ligand=False):
    """
    Collect RMSD/RMSF features of every structure into a single table.

    :param api: client used to download the input files; if None, file paths
        are read from the ``desmond_cms`` and ``trj_rms`` columns of *dataset*
    :type api: pldbclient.api_client.Api
    :param dataset: table with a ``structure_id`` column. When *api* is None,
        rows lacking ``desmond_cms`` or ``trj_rms`` are dropped in place.
    :type dataset: pandas.Dataframe
    :param data_type: Placeholder
    :param no_ligand: Placeholder
    :return: DataFrame indexed by structure id holding ``mean_rmsd``, one
        ``<a>:<b>`` column per C-alpha RMSF entry and one ``#<atom index>``
        column per heavy ligand atom RMSF entry
    """
    if api is None:
        dataset.dropna(axis=0, subset=['desmond_cms', 'trj_rms'], inplace=True)

    df_rms = pd.DataFrame(index=dataset.structure_id)

    for row_label, structure_id in zip(dataset.index, dataset.structure_id):
        with tempfile.TemporaryDirectory() as workdir:
            if api is not None:
                # Downloads land in the scratch dir and are read into memory
                msys_model, cms_model = get_desmond_cms(api, structure_id,
                                                        dir=workdir)
                trj_rms = get_trj_rms(api, structure_id, dir=workdir)
            else:
                msys_model, cms_model = topo.read_cms(
                    dataset.loc[row_label, 'desmond_cms'])
                with open(dataset.loc[row_label, 'trj_rms'], 'r') as handle:
                    trj_rms = json.load(handle)

        for entry in trj_rms:
            name = entry['name']
            if name == 'calpha rmsd':
                df_rms.loc[structure_id, 'mean_rmsd'] = np.mean(entry['results'])
            elif name == 'calpha rmsf':
                for atom_id, value in zip(entry['atom_ids'], entry['results']):
                    column = '{}:{}'.format(*atom_id[:-1])
                    df_rms.loc[structure_id, column] = value
            elif name == 'ligand_rmsf':
                # '<resnum>:<pdbname>' -> atom index for heavy ligand atoms
                key_to_index = {
                    '{}:{}'.format(str(cms_model.atom[n].resnum),
                                   cms_model.atom[n].pdbname.strip()): n
                    for n in evaluate_asl(cms_model,
                                          'ligand and not a.element H')
                }
                for atom_id, value in zip(entry['atom_ids'], entry['results']):
                    atom_index = key_to_index.get('{}:{}'.format(*atom_id[1:]))
                    # Legacy files may contain pseudoatoms with no match
                    if atom_index is not None:
                        df_rms.loc[structure_id, '#{}'.format(atom_index)] = value
    return df_rms
Beispiel #6
0
def preprocess_torsion(api, dataset, data_type=None, no_ligand=False):
    """
    Compute the entropy of every torsion-angle distribution per structure.

    Each structure's ``trj_torsion`` archive is unpacked into a temporary
    directory. The archive is read as containing ``torsion_ids.csv`` (torsion
    id -> four atom indices) and one ``torsion_<id>.csv`` series per torsion.
    Each series is histogrammed into 18-degree bins over [-180, 180] and
    ``entropy`` of the normalized histogram is stored under the column label
    returned by ``label_torsion``.

    :param api: client used to download the input files; if None, file paths
        are read from the ``desmond_cms`` and ``trj_torsion`` columns of *dataset*
    :param dataset: DataFrame with a ``structure_id`` column. When *api* is
        None, rows missing ``desmond_cms`` or ``trj_torsion`` are dropped in
        place (the caller's DataFrame is modified).
    :param data_type: Placeholder
    :param no_ligand: Placeholder
    :return: DataFrame indexed by structure id, one column per torsion label
    """
    if api is None:  # Only process structures which contain all input files
        dataset.dropna(axis=0,
                       subset=['desmond_cms', 'trj_torsion'],
                       inplace=True)

    torsion_entropy = pd.DataFrame(index=dataset.structure_id)

    bins = np.arange(-180, 198, 18)  # 21 edges -> 20 bins of 18 degrees

    cwd = os.path.abspath('./')

    for index, stid in zip(dataset.index, dataset.structure_id):
        with tempfile.TemporaryDirectory() as tempdir:
            if api is None:
                # Resolve local paths BEFORE changing directory, otherwise
                # relative paths would be looked up inside tempdir
                cms_path = os.path.abspath(dataset.loc[index, 'desmond_cms'])
                trj_torsion = os.path.abspath(dataset.loc[index, 'trj_torsion'])

            os.chdir(tempdir)
            try:
                if api is None:
                    msys_model, cms_model = topo.read_cms(cms_path)
                else:
                    # Load cms_model
                    msys_model, cms_model = get_desmond_cms(api, stid,
                                                            dir=tempdir)
                    trj_torsion = get_trj_torsion(api, stid, dir=tempdir)

                with tarfile.open(trj_torsion, 'r:gz') as tar:
                    # NOTE(review): extractall trusts member paths; consider an
                    # extraction filter if archives may come from untrusted sources
                    tar.extractall(path=tempdir)

                torsion_ids = pd.read_csv('torsion_ids.csv', sep=',',
                                          index_col=0)
                for tid in torsion_ids.index:
                    torsion_label = label_torsion(
                        cms_model, *list(map(int, torsion_ids.loc[tid, :])))
                    series = np.genfromtxt('torsion_{}.csv'.format(int(tid)),
                                           delimiter=',')
                    # Keep `bins` untouched; the returned edges are not needed
                    counts, _ = np.histogram(series, bins=bins)
                    freq = counts / np.sum(counts)
                    torsion_entropy.loc[stid, torsion_label] = entropy(freq)
            finally:
                # Always leave tempdir before it is removed, even on errors
                os.chdir(cwd)
    return torsion_entropy
Beispiel #7
0
 def from_files(cls, cms_fname, traj_fname):
     """
     Alternate constructor: build an instance from a cms file and a trajectory.

     :param cms_fname: path passed to ``topo.read_cms``; only the second
         element of its result (the cms model) is kept
     :param traj_fname: trajectory path passed to ``traj.read_traj``
     :return: ``cls(cms_model, frames)``
     """
     # Local import: the Schrodinger packages are only loaded when this runs
     from schrodinger.application.desmond.packages import topo, traj
     _msys_model, cms_model = topo.read_cms(cms_fname)
     frames = traj.read_traj(traj_fname)
     return cls(cms_model, frames)
Beispiel #8
0
def merge_torsion(data, dataset=None, api=None):
    """
    Merge torsion tables and map ligand torsions onto common-substructure labels.

    Ligand torsion columns are named ``'#a1:a2:a3:a4'``. Each one whose four
    atoms all belong to the common substructure shared by every ligand is
    stored as ``'LIG:<c2>:<c3>'``, where c2/c3 are the sorted 1-based common
    ids of the two central atoms. Other ligand torsion columns are dropped.

    :param data: DataFrames (indexed by structure id) to concatenate
    :type data: list[pd.DataFrame]
    :param dataset: not used by this function
    :type dataset: pd.DataFrame
    :param api: client used to download each structure's desmond cms file
    :type api: pldbclient.api_client.Api
    :return: ``(merged_df, ligand_st)``, or ``(df_without_ligand_columns,
        None)`` when some ligand has no atom in the common substructure
    :raises HTTPError: if a cms file cannot be downloaded (also logged)
    """
    if len(data) == 1:
        df_raw = data[0]  # type: pd.DataFrame
    else:
        df_raw = pd.concat(data, sort=False)  # type: pd.DataFrame

    # ligand_torsion columns are indicated by '#<a1:a2:a3:a4>'
    ligand_col = [
        c for c in df_raw.columns if isinstance(c, str) and c.startswith('#')
    ]

    # Get structures
    cms_models = []
    for stid in df_raw.index:
        with tempfile.TemporaryDirectory() as tempdir:
            try:
                msys_model, cms_model = get_desmond_cms(api, stid, dir=tempdir)
            except HTTPError as e:
                # Do not continue with an error response as if it were a cms file
                logger.error(e)
                raise
        cms_models.append(cms_model)

    common_atoms, ligand_st = get_min_common_substructure(
        cms_models, return_st=True, return_common_atoms=True)
    # all() is False as soon as ANY structure has an empty common-atom list
    if not all(common_atoms):
        logger.warning('Ligands do not share a common substructure')
        # Non-mutating drop: data[0] may be the caller's DataFrame
        return df_raw.drop(ligand_col, axis=1), None

    df_merged = df_raw.drop(ligand_col, axis=1)  # type: pd.DataFrame
    # Map each structure's ligand atom index to its 1-based common atom id
    cid_map = {stid: {} for stid in df_merged.index}
    for cid, atom_array in enumerate(zip(*common_atoms), start=1):
        for stid, aid in zip(df_merged.index, atom_array):
            cid_map[stid][aid] = cid
    for stid in df_merged.index:
        for column in ligand_col:
            x = df_raw.loc[stid, column]
            if pd.isna(x):
                continue
            # torsion ids == '#a1:a2:a3:a4'
            torsion_atoms = [int(a) for a in column[1:].split(':')]
            if all(aid in cid_map[stid] for aid in torsion_atoms):
                # Label by the (sorted) central bond so equivalent torsions match
                torsion_bond = sorted([
                    cid_map[stid][torsion_atoms[1]],
                    cid_map[stid][torsion_atoms[2]]
                ])
                df_merged.loc[stid, 'LIG:{}:{}'.format(*torsion_bond)] = x
    return df_merged, ligand_st
Beispiel #9
0
def merge_rms(data, dataset=None, api=None):
    """
    Merge the trj_rms data and map ligand RMSF onto common-substructure labels.

    Merger can also be called with only one data frame. Ligand RMSF columns
    (``'#<atom_id>'``) whose atom belongs to the common substructure shared by
    every ligand are stored as ``'LIG:<common id>'`` (1-based); the remaining
    ligand columns are dropped.

    :param data: DataFrames (indexed by structure id) to concatenate
    :type data: list[pd.DataFrame]
    :param dataset: required when *api* is None: provides ``desmond_cms``
        paths and ``ligand_resnum``/``ligand_chain``. Rows missing
        ``desmond_cms`` or ``trj_rms`` are dropped from it in place.
    :type dataset: pd.DataFrame
    :param api: client used to download the cms files; if None, *dataset* is used
    :type api: pldbclient.api_client.Api
    :return: ``(merged_df, ligand_st)``, or ``(df_without_ligand_columns,
        None)`` when some ligand has no atom in the common substructure
    """
    import tempfile

    if len(data) == 1:
        df_raw = data[0]
    else:
        df_raw = pd.concat(data, sort=False)

    if api is None:
        dataset.dropna(axis=0, subset=['desmond_cms', 'trj_rms'], inplace=True)
        ligand_ids = dataset.loc[:, ['ligand_resnum', 'ligand_chain']]
    else:
        ligand_ids = None

    # ligand_rmsf columns are indicated by '#<atom_id>'
    ligand_col = [
        c for c in df_raw.columns if isinstance(c, str) and c.startswith('#')
    ]

    # Get structures
    cms_models = []
    for index in df_raw.index:
        if api is None:
            msys_model, cms_model = topo.read_cms(dataset.loc[index,
                                                              'desmond_cms'])
        else:
            # Scoped temp dir: removed right after loading instead of
            # accumulating one directory per structure until exit
            with tempfile.TemporaryDirectory() as tempdir:
                msys_model, cms_model = get_desmond_cms(api, index, dir=tempdir)
        cms_models.append(cms_model)

    # Get a list (of lists) of common atoms
    common_atoms, ligand_st = get_min_common_substructure(
        cms_models,
        ligand_ids=ligand_ids,
        return_st=True,
        return_common_atoms=True)
    # all() is False as soon as ANY structure has an empty common-atom list
    if not all(common_atoms):
        logger.warning('Ligands do not share a common substructure')
        # Non-mutating drop: data[0] may be the caller's DataFrame
        return df_raw.drop(ligand_col, axis=1), None

    df_merged = df_raw.drop(ligand_col, axis=1)
    # Map each structure's ligand atom index to its 1-based common atom id
    cid_map = {stid: {} for stid in df_merged.index}
    for cid, atom_array in enumerate(zip(*common_atoms), start=1):
        for stid, aid in zip(df_merged.index, atom_array):
            cid_map[stid][aid] = cid
    for stid in df_merged.index:
        for column in ligand_col:
            x = df_raw.loc[stid, column]
            if pd.isna(x):
                continue
            aid = int(column[1:])  # Drop the preceding '#'
            if aid in cid_map[stid]:  # If aid has a common atom id
                df_merged.loc[stid, 'LIG:{}'.format(cid_map[stid][aid])] = x
    return df_merged, ligand_st
Beispiel #10
0
def _process(structure_dict):
    """
    Run the RMSD/RMSF calculations for one simulation.

    :param structure_dict: dict with ``structure`` (``code``, ``structure_id``),
        ``files`` (``desmond_cms``, ``desmond_trjtar``, optional ``ligand_mae``)
        and ``custom`` entries. If ``custom`` holds a ``pipeline`` list, its
        first element is popped (mutating *structure_dict*) and forwarded as
        a fork.
    :return: generator yielding one transformer dict that points at the
        ``<code>_trj_rms.json`` results file
    :raises RuntimeError: if ``desmond_trjtar`` is neither a directory nor a
        tar file
    """
    fork = None
    # Check if transformers is called as part of a pipeline
    if 'pipeline' in structure_dict['custom']:
        pipeline = structure_dict['custom']['pipeline']
        fork = [pipeline[0]]
        if len(pipeline) == 1:
            del structure_dict['custom']['pipeline']
        else:
            structure_dict['custom']['pipeline'] = pipeline[1:]

    outname = '{}_trj_rms'.format(structure_dict['structure']['code'])
    outfile = outname + '.json'
    results = []
    # Load simulation files
    cms_file = structure_dict['files']['desmond_cms']
    trjtar = structure_dict['files']['desmond_trjtar']

    # If run from command line it does not make sense to provide a tarfile
    if os.path.isdir(trjtar):
        trj_dir = trjtar
    elif tarfile.is_tarfile(trjtar):
        # 'r:*' auto-detects compression; is_tarfile also accepts plain tars
        with tarfile.open(name=trjtar, mode='r:*') as tfile:
            # NOTE(review): extractall trusts member paths in the archive
            tfile.extractall()
            logger.info('extracting frameset')
            # Assumes the first member is the frame-archive directory
            trj_dir = tfile.getnames()[0]
    else:
        raise RuntimeError('trjtar is neither a directory nor a tarfile')

    def ligand_rmsf_param(align_asl):
        # One ligand RMSF calculation aligned on (and measured over) align_asl
        return {
            'name': 'ligand_rmsf',
            'ref': None,
            'align_asl': align_asl,
            'calc_asl': None,
            'calculation_type': 'rmsf'
        }

    # Check whether a ligand exists and add ligand rmsd to set of calculations
    calculation_param = copy.copy(CALCULATION_PARAM)
    ligand_mae = structure_dict['files'].get('ligand_mae')
    if ligand_mae is None:
        msys_model, cms_model = topo.read_cms(str(cms_file))
        for ligand in find_ligands(cms_model):
            ligand_asl = ' or '.join([
                '(r. {} and c. "{}" and not a.element H )'.format(
                    res.resnum, res.chain) for res in ligand.st.residue
            ])
            calculation_param.append(ligand_rmsf_param(ligand_asl))
    else:
        ligand_st = structure.Structure.read(str(ligand_mae))
        ligand_asl = '( ' + ' or '.join([
            '(r.ptype {} and c. "{}")'.format(res.pdbres, res.chain)
            for res in ligand_st.residue
        ]) + ' ) and not a.element H'
        calculation_param.append(ligand_rmsf_param(ligand_asl))

    out_queue = multiprocessing.Queue()
    # One worker process per calculation; one result is read per worker below.
    # NOTE(review): NPROC is not applied as an upper bound here. Capping it
    # would also change how many results must be read from the queue — check
    # how RMS.run reports results first.
    nproc = len(calculation_param)
    logger.info('Performing {} rms calculations, using {} cores.'.format(
        len(calculation_param), nproc))

    calculation_param = np.array_split(calculation_param, nproc)
    workers = []
    for i, params in enumerate(calculation_param):
        workers.append(RMS(cms_file, trj_dir, out_queue, params=params))
        logger.info('starting subprocess {}'.format(i))
        workers[-1].start()
    # Drain the queue before joining so workers are not blocked on put()
    for _ in range(nproc):
        results.append(out_queue.get())
    for w in workers:
        w.join()
    with open(outfile, 'w') as f:
        json.dump(results, f)

    transformer_dict = {
        'structure': {
            'parent_structure_id': structure_dict['structure']['structure_id']
        },
        'files': {
            'trj_rms': outfile
        },
        'custom': structure_dict['custom']
    }
    if fork is not None:
        logger.info('Forking pipeline: ' + ' '.join(fork))
        transformer_dict['control'] = {'forks': fork}

    yield transformer_dict
Beispiel #11
0
def _process(structure_dict):
    """
    Compute ligand and protein torsion-angle time series for one simulation.

    Writes ``torsion_ids.csv`` (row i -> four original atom indices) and one
    ``torsion_<i>.csv`` series per torsion to the working directory and packs
    them into ``<code>_torsion.tar.gz``.

    :param structure_dict: dict with ``structure`` (``code``, ``structure_id``),
        ``files`` (``desmond_cms``, ``desmond_trjtar``) and ``custom`` entries.
        If ``custom`` holds a ``pipeline`` list, its first element is popped
        (mutating *structure_dict*) and forwarded as a fork.
    :return: generator yielding one transformer dict pointing at the archive
    :raises RuntimeError: if ``desmond_trjtar`` is neither a directory nor a
        tar file
    """
    t = time.time()

    fork = None
    # Check if transformers is called as part of a pipeline
    if 'pipeline' in structure_dict['custom']:
        pipeline = structure_dict['custom']['pipeline']
        fork = [pipeline[0]]
        if len(pipeline) == 1:
            del structure_dict['custom']['pipeline']
        else:
            structure_dict['custom']['pipeline'] = pipeline[1:]

    structure_code = structure_dict['structure']['code']

    logger.info('Load simulation files')

    cms_file = structure_dict['files']['desmond_cms']

    msys_model, cms_model = topo.read_cms(str(cms_file))

    trjtar = structure_dict['files']['desmond_trjtar']

    # If run from command line it does not make sense to provide a tarfile
    if os.path.isdir(trjtar):
        trj_dir = trjtar
    elif tarfile.is_tarfile(trjtar):
        # 'r:*' auto-detects compression; is_tarfile also accepts plain tars
        with tarfile.open(name=trjtar, mode='r:*') as tfile:
            # NOTE(review): extractall trusts member paths in the archive
            tfile.extractall()
            # Assumes the first member is the frame-archive directory
            trj_dir = tfile.getnames()[0]
    else:
        raise RuntimeError('trjtar is neither a directory nor a tarfile')

    frame_list = traj.read_traj(str(trj_dir))
    # Keep every STEP-th frame
    frame_list = [frame_list[i] for i in range(0, len(frame_list), STEP)]

    logger.info('Calculating torsion angles')

    cms_model = set_original_atom_index(cms_model)
    ligand_ct = cms_model.extract(evaluate_asl(cms_model, 'ligand'), copy_props=True)

    torsion_list = get_hetero_torsion_atoms(ligand_ct, element_priority=ELEMENT_PRIORITY)
    torsion_list.extend(get_protein_torsion_atoms(cms_model))

    analyzers = []

    torsion_ids = pd.DataFrame(columns=['index', 'aid1', 'aid2', 'aid3', 'aid4'])
    torsion_ids.set_index('index', inplace=True)

    for i, atom_set in enumerate(torsion_list):
        # Translate extracted-structure indices back to cms_model indices
        atom_set = list(map(get_original_atom_index, atom_set))
        analyzers.append(Torsion(msys_model, cms_model, *atom_set))
        torsion_ids.loc[i, ['aid1', 'aid2', 'aid3', 'aid4']] = atom_set

    if not analyzers:
        # analyze() requires at least one analyzer
        results = []
    else:
        results = analyze(frame_list, *analyzers,
                          progress_feedback=functools.partial(print_iframe, logger=logger))
        if len(analyzers) == 1:
            # analyze() returns the bare result, not a one-element list, for a
            # single analyzer; wrap it so results[i] matches torsion_ids row i
            results = [results]

    out_arch = '{}_torsion.tar.gz'.format(structure_code)
    with tarfile.open(out_arch, 'w:gz') as tar:
        torsion_ids.to_csv('torsion_ids.csv', sep=',')
        tar.add('torsion_ids.csv')
        for i, timeseries in enumerate(results):
            fname = 'torsion_{}.csv'.format(i)
            np.savetxt(fname, timeseries, delimiter=',')
            tar.add(fname)

    logger.info('Calculated torsion angles in {:.0f} seconds'.format(time.time() - t))
    # Return structure dict
    transformer_dict = {
        'structure': {
            'parent_structure_id':
                structure_dict['structure']['structure_id'],
            'searchable': False
        },
        'files': {'trj_torsion': out_arch},
        'custom': structure_dict['custom']
    }
    if fork is not None:
        logger.info('Forking pipeline: ' + ' '.join(fork))
        transformer_dict['control'] = {'forks': fork}
    yield transformer_dict
Beispiel #12
0
def _process(structure_dict):
    """
    Compute pairwise nonbonded energies between atom groups for one simulation.

    Runs a desmond job (via ``VRUN``) on a GPU host, parses the per-group
    energy components and stores, for every pair with non-negligible mean
    potential, the mean and its error in ``<code>_trj_nonbonded.json``. If
    ``RAW`` is set, raw outputs are also packed into
    ``<code>_trj_nonbonded.tar.bz2``.

    :param structure_dict: dict with ``structure`` (``code``, ``structure_id``),
        ``files`` (``desmond_cms``, ``desmond_cfg``, ``desmond_trjtar``,
        optional ``desmond_nonbonded_raw``) and ``custom`` entries. If
        ``custom`` holds a ``pipeline`` list, its first element is popped
        (mutating *structure_dict*) and forwarded as a fork.
    :return: generator yielding one transformer dict pointing at the outputs
    :raises RuntimeError: if ``desmond_trjtar`` is neither a directory nor a
        tar file
    """
    fork = None
    # Check if transformers is called as part of a pipeline
    if 'pipeline' in structure_dict['custom']:
        pipeline = structure_dict['custom']['pipeline']
        fork = [pipeline[0]]
        if len(pipeline) == 1:
            del structure_dict['custom']['pipeline']
        else:
            structure_dict['custom']['pipeline'] = pipeline[1:]

    structure_code = structure_dict['structure']['code']

    outname = '{}_trj_nonbonded'.format(structure_code)
    outfile = outname + '.json'
    outfile_raw = outname + '.tar.bz2'

    # desmond_cms file
    cmsfile = structure_dict['files']['desmond_cms']
    msys_model, cms_model = topo.read_cms(str(cmsfile))

    # desmond_cfg file
    cfgfile = structure_dict['files']['desmond_cfg']

    # desmond frame archive
    trjtar = structure_dict['files']['desmond_trjtar']

    # If run from command line it does not make sense to provide a tarfile
    if os.path.isdir(trjtar):
        trj_dir = trjtar
    elif tarfile.is_tarfile(trjtar):
        # 'r:*' auto-detects compression; is_tarfile also accepts plain tars
        with tarfile.open(name=trjtar, mode='r:*') as tfile:
            # NOTE(review): extractall trusts member paths in the archive
            tfile.extractall()
            logger.info('extracting frameset')
            # Assumes the first member is the frame-archive directory
            trj_dir = tfile.getnames()[0]
    else:
        raise RuntimeError('trjtar is neither a directory nor a tarfile')

    logger.info('creating atomgroups')

    _id, atom_groups, nonbonded_dict, structure_dict, fork = assign_atomgroups(
        structure_dict, cms_model, fork)

    logger.info('Finding free host')
    logger.debug('Hosts: ' + ', '.join(HOSTS))
    host = get_gpu_host(HOSTS)

    logger.info('Running desmond job on: {}'.format(host))
    vrun_obj = VRUN(cms_model, trj_dir, cfgfile, atom_groups)
    energy_group_file = vrun_obj.calculate_energy(outname, host=host)

    atom_groups_file = '{}_atom_groups.json'.format(_id)
    if RAW:
        # Write raw output
        with open(atom_groups_file, 'w') as f:
            json.dump(atom_groups, f)
        nonbonded_raw = structure_dict['files'].get('desmond_nonbonded_raw')
        if nonbonded_raw is None:
            with tarfile.open(outfile_raw, 'w:bz2') as tar:
                for fn in [energy_group_file, atom_groups_file]:
                    tar.add(fn)
        else:
            # Carry over the previous raw outputs, then append the new ones
            with tarfile.open(nonbonded_raw, 'r:bz2') as tar:
                tar.extractall()
                members = tar.getmembers()
            with tarfile.open(outfile_raw, 'w:bz2') as tar:
                for member in members:
                    filename = member.name
                    if filename.split('_')[0] not in (
                            'custom', 'default'):  # Backward Compatibility
                        logger.warning('Updating file names to new version')
                        os.rename(filename, 'default_' + filename)
                        filename = 'default_' + filename
                    tar.add(filename)
                for fn in [energy_group_file, atom_groups_file]:
                    tar.add(fn)

    # Get time and energy components
    sim_time, component_dict = parse_output(energy_group_file,
                                            len(atom_groups),
                                            ENERGY_COMPONENTS)
    # calculate mean and error
    results = {}
    logger.info('Calculating average potential')
    for comp in ENERGY_COMPONENTS:
        pair_dict = component_dict[comp]
        data_dict = {}
        results[comp] = {'keys': [], 'mean_potential': [], 'error': []}
        for pair, energy in pair_dict.items():
            mean_potential = np.mean(energy)
            # skip zero and near 0 potential (The latter sometimes causes overflow issues during error calc.)
            if np.abs(mean_potential) < 1e-6:
                continue
            results[comp]['keys'].append(list(pair))
            results[comp]['mean_potential'].append(mean_potential)
            data_dict[pair] = energy  # Only store non-zero energies
        # Calculate the error separately over multiple processes
        nproc = dynamic_cpu_assignment(NPROC)
        logger.debug('Calculating error for {}'.format(comp))
        error_dict = get_error(data_dict, nproc)
        for k in results[comp]['keys']:
            results[comp]['error'].append(error_dict[tuple(k)])

    nonbonded_dict[_id]['energy_components'] = ENERGY_COMPONENTS
    nonbonded_dict[_id]['time'] = sim_time
    nonbonded_dict[_id]['results'] = results

    with open(outfile, 'w') as f:
        json.dump(nonbonded_dict, f)

    files = {'desmond_nonbonded': outfile}
    if RAW:
        # The raw archive only exists when RAW output was written above
        files['desmond_nonbonded_raw'] = outfile_raw

    transformer_dict = {
        'structure': {
            'parent_structure_id': structure_dict['structure']['structure_id']
        },
        'files': files,
        'custom': structure_dict['custom'],
    }
    if fork is not None:
        logger.info('Forking pipeline: ' + ' '.join(fork))
        transformer_dict['control'] = {'forks': fork}
    yield transformer_dict
Beispiel #13
0
def preprocess_elec(api, dataset, data_type='default', no_ligand=False):
    """
    Get pairwise electrostatic interactions (mean ``nonbonded_elec`` potential).

    :param api: client used to download the input files; if None, file paths
        are read from the ``desmond_cms`` and ``desmond_nonbonded`` columns of
        *dataset*
    :param dataset: DataFrame with a ``structure_id`` column and optionally
        ``ligand_resnum``/``ligand_chain``. When *api* is None, rows missing
        ``desmond_cms`` or ``desmond_nonbonded`` are dropped in place.
    :param data_type: key selecting the atom-group set inside the nonbonded
        file. ``'default'`` also accepts the legacy flat format (a top-level
        ``group_ids`` entry).
    :param no_ligand: If true do not rename the ligand residue
    :return: DataFrame indexed by structure id with one ``'<a> - <b>'`` column
        per group pair (sorted labels, ligand shown as ``LIG``); missing pairs
        are 0
    :raises ValueError: if *data_type* is None or missing from a file, or the
        ligand residue cannot be determined unambiguously
    """
    import tempfile

    if data_type is None:
        raise ValueError(
            'elec data_type not set. Specify elec data_type or use default data_types'
        )

    energy_component = 'nonbonded_elec'

    if api is None:
        dataset.dropna(axis=0,
                       subset=['desmond_cms', 'desmond_nonbonded'],
                       inplace=True)

    has_resnum = 'ligand_resnum' in dataset.columns
    has_chain = 'ligand_chain' in dataset.columns

    elec_dict = {index: {} for index in dataset.index}

    for index, stid in zip(dataset.index, dataset.structure_id):
        # Load data
        if api is None:
            msys_model, cms_model = topo.read_cms(dataset.loc[index,
                                                              'desmond_cms'])
            with open(dataset.loc[index, 'desmond_nonbonded'], 'r') as fh:
                nonbonded_dict = json.load(fh)
        else:
            # Load desmond_cms in a scoped temp dir (removed right after loading)
            with tempfile.TemporaryDirectory() as tempdir:
                msys_model, cms_model = get_desmond_cms(api, stid, dir=tempdir)
            # Load desmond_nonbonded
            nonbonded_dict = get_desmond_nonbonded(api, stid)

        # Check if data_type in nonbonded_dict
        if data_type not in nonbonded_dict:
            if data_type != 'default':
                raise ValueError('{} data type not calculated for {}'.format(
                    data_type, stid))
            elif 'group_ids' in nonbonded_dict:
                # Legacy flat format: the file itself is the 'default' entry
                logger.warning('{} uses outdated file format'.format(stid))
            else:
                raise ValueError('{} data type not calculated for {}'.format(
                    data_type, stid))
        else:
            nonbonded_dict = nonbonded_dict[data_type]

        # Determine Ligand residue
        if no_ligand:
            ligand_resid = None
        elif has_resnum and has_chain:
            ligand_resid = (int(dataset.loc[index, 'ligand_resnum']),
                            dataset.loc[index, 'ligand_chain'])
        elif has_resnum or has_chain:
            raise ValueError(
                'Structure {}\nligand_resnum and ligand_chain must be given '
                'together'.format(stid))
        else:
            ligands = find_ligands(st=cms_model)
            if len(ligands) != 1:
                raise ValueError(
                    'Structure {}\nExpected exactly one ligand, found {}'.format(
                        stid, len(ligands)))
            resids = {(atm.resnum, atm.chain.strip())
                      for atm in ligands[0].st.atom}
            if len(resids) > 1:
                raise ValueError(
                    'Structure {}\nFound multiple ligand residues.'.format(stid))
            ligand_resid = next(iter(resids), None)

        # Map group ids
        group_ids = nonbonded_dict['group_ids']
        id2resid = {}

        if all(isinstance(gid, list) for gid in group_ids):
            # Residue groups given as [resnum, chain]
            for i, resid in enumerate(map(tuple, group_ids)):
                if resid == ligand_resid:
                    id2resid[i] = 'LIG'
                else:
                    id2resid[i] = '{}:{}'.format(*resid)
        else:
            logger.warning('Custom group_ids found')
            for i, group_id in enumerate(group_ids):
                id2resid[i] = group_id

        elec_keys = list(
            map(tuple, nonbonded_dict['results'][energy_component]['keys']))
        elec_means = nonbonded_dict['results'][energy_component][
            'mean_potential']
        for key, mean in zip(elec_keys, elec_means):
            ri, rj = list(map(int, key))
            # Sorting ensures common pairs are always assigned the same pair_id
            pair = sorted([id2resid[ri], id2resid[rj]])
            pair_id = '{} - {}'.format(*pair)
            # First occurrence wins if two groups map to the same label
            if pair_id not in elec_dict[index]:
                elec_dict[index][pair_id] = mean

    df_elec = pd.DataFrame(elec_dict).T
    # replace NaNs with 0
    df_elec.fillna(0, inplace=True)
    df_elec.index = dataset.structure_id
    return df_elec