コード例 #1
0
def test_read_write_versioning(lgf_data_dir):
    """Round-trip a versioning repository through an in-memory buffer.

    The private settings and both charge tables must survive write/read,
    neither repository may carry ``iso_iacm``/``iso_elem`` attributes, and
    sampled entries of the reloaded repository must be _VersioningList
    instances.
    """
    original = Repository.create_from(str(lgf_data_dir), versioning=True)

    buffer = BytesIO()
    original.write(buffer)
    restored = Repository.read(buffer, versioning=True)

    for private_attr in ('_Repository__min_shell', '_Repository__max_shell',
                         '_Repository__traceable', '_Repository__versioning'):
        assert getattr(original, private_attr) == getattr(restored, private_attr)

    assert original.charges_iacm == restored.charges_iacm
    assert original.charges_elem == restored.charges_elem
    for repo in (original, restored):
        for iso_attr in ('iso_iacm', 'iso_elem'):
            assert not hasattr(repo, iso_attr)

    sampled_entries = [
        (restored.charges_iacm, 1, 'c18208da9e290c6faf8a0c58017d24d9'),
        (restored.charges_iacm, 3, '76198d87470cc1b2f871da60449146bc'),
        (restored.charges_elem, 1, '92ed00c54b2190be94748bee34b22847'),
        (restored.charges_elem, 3, '17ac3199bf634022485c145821f358d5'),
    ]
    for table, shell_size, key in sampled_entries:
        assert isinstance(table[shell_size][key], _VersioningList)
コード例 #2
0
def _assert_traceable_charges_equal(expected, actual):
    """Assert that two traceable charge tables hold the same entries.

    Both tables map shell size -> key -> list of entries. msgpack writes
    tuples as lists, so the entries of ``actual`` (the table that was read
    back) are converted to tuples before comparing them with ``expected``.
    """
    assert expected.keys() == actual.keys()
    for shell_size, expected_shell in expected.items():
        actual_shell = actual[shell_size]
        assert expected_shell.keys() == actual_shell.keys()
        for key, expected_entries in expected_shell.items():
            # Convert the *read-back* entries; converting the original's own
            # entries would compare the table with itself and prove nothing.
            assert expected_entries == [tuple(x) for x in actual_shell[key]]


def test_read_write_traceable(lgf_data_dir):
    """Round-trip a traceable repository through an in-memory buffer.

    The repository read back must match the original in its private
    settings, both charge tables and the ``iso_iacm``/``iso_elem`` tables.
    """
    repo0 = Repository.create_from(str(lgf_data_dir), traceable=True)

    tmp = BytesIO()
    repo0.write(tmp)

    repo1 = Repository.read(tmp)

    assert repo0._Repository__min_shell == repo1._Repository__min_shell
    assert repo0._Repository__max_shell == repo1._Repository__max_shell
    assert repo0._Repository__traceable == repo1._Repository__traceable
    assert repo0._Repository__versioning == repo1._Repository__versioning

    # msgpack writes tuples as lists. this is ok for our purposes, but we need
    # to take care of that in testing!
    _assert_traceable_charges_equal(repo0.charges_iacm, repo1.charges_iacm)
    _assert_traceable_charges_equal(repo0.charges_elem, repo1.charges_elem)

    assert repo0.iso_iacm == repo1.iso_iacm
    assert repo0.iso_elem == repo1.iso_elem
コード例 #3
0
def test_create_traceable_from_dir(lgf_data_dir):
    """Build a traceable repository from the LGF test data and inspect it.

    In a traceable repository each charge entry is a tuple (for example
    ``(-0.516, 15610, 2)``) rather than a bare charge, and the
    ``iso_iacm``/``iso_elem`` attributes are present.
    """
    repo = Repository.create_from(str(lgf_data_dir), traceable=True)
    assert repo._Repository__min_shell == 1
    assert repo._Repository__max_shell == 7
    assert repo._Repository__traceable == True
    assert repo._Repository__versioning == False

    one_entry = [(-0.516, 15610, 2)]
    three_entries = [(0.077, 1204, 1), (0.077, 1204, 4), (0.077, 1204, 5)]

    iacm = repo.charges_iacm
    assert len(iacm) == 7
    for shell_size, expected_len in ((1, 10), (2, 14), (3, 15)):
        assert len(iacm[shell_size]) == expected_len
    assert iacm[1]['c18208da9e290c6faf8a0c58017d24d9'] == one_entry
    assert iacm[3]['76198d87470cc1b2f871da60449146bc'] == three_entries

    elem = repo.charges_elem
    assert len(elem) == 7
    for shell_size, expected_len in ((1, 9), (2, 14), (3, 15)):
        assert len(elem[shell_size]) == expected_len
    assert elem[1]['92ed00c54b2190be94748bee34b22847'] == one_entry
    assert elem[3]['17ac3199bf634022485c145821f358d5'] == three_entries

    for table_name in ('iso_iacm', 'iso_elem'):
        assert hasattr(repo, table_name)
    assert len(repo.iso_iacm) == 2
    assert len(repo.iso_elem) == 2
コード例 #4
0
    def test_charge_molecule(self):
        """Test case for charge_molecule

        Builds a repository from the sample LGF data, initialises the
        charge server with it, then POSTs a five-atom molecule (C1 bonded
        to H1-H4) to /charge_assign and expects an HTTP 200 response.
        """
        sample_lgf_dir = (Path(__file__).parent.parents[1]
                          / 'charge' / 'test' / 'sample_lgf')
        repository = Repository.create_from(str(sample_lgf_dir))
        with tempfile.TemporaryDirectory() as scratch_dir:
            repo_file = str(Path(scratch_dir) / 'repo.zip')
            repository.write(repo_file)
            # init() reads the repository file, so it has to run before the
            # temporary directory is cleaned up.
            charge_server.init(repo_file)

        lgf_lines = [
            '@nodes',
            'label\tlabel2\tatomType\tinitColor\t',
            '1\tC1\t12\t0\t',
            '2\tH1\t20\t0\t',
            '3\tH2\t20\t0\t',
            '4\tH3\t20\t0\t',
            '5\tH4\t20\t0\t',
            '@edges',
            '\t\tlabel\t',
            '1\t2\t0\t',
            '1\t3\t1\t',
            '1\t4\t2\t',
            '1\t5\t3\t',
        ]
        # Every line, including the last, is newline-terminated.
        molecule_lgf = ''.join(line + '\n' for line in lgf_lines)

        response = self.client.open(
            '/charge_assign',
            method='POST',
            data=molecule_lgf,
            content_type='text/plain',
            query_string=[('total_charge', 0)],
        )
        body_text = response.data.decode('utf-8')
        self.assert200(response, 'Response body is : ' + body_text)
コード例 #5
0
def main() -> None:
    """Build a repository from the command-line input and write it out.

    The input location, maximum shell size, traceability flag and output
    location all come from get_args().
    """
    options = get_args()
    built = Repository.create_from(
        options.input,
        max_shell=options.max_shell,
        traceable=options.traceable,
    )
    built.write(options.repository)
コード例 #6
0
def test_create_empty_traceable():
    """A default-sized traceable Repository starts with all tables empty."""
    repo = Repository(traceable=True)

    expected_settings = (
        ('min_shell', 1),
        ('max_shell', 7),
        ('traceable', True),
        ('versioning', False),
    )
    for setting, value in expected_settings:
        assert getattr(repo, '_Repository__' + setting) == value

    for table_name in ('charges_iacm', 'charges_elem', 'iso_iacm', 'iso_elem'):
        assert len(getattr(repo, table_name)) == 0
コード例 #7
0
def test_read_write(lgf_data_dir):
    """Round-trip a plain repository through an in-memory buffer.

    Settings and both charge tables must be preserved, and neither the
    original nor the reloaded repository may have ``iso_iacm``/``iso_elem``.
    """
    original = Repository.create_from(str(lgf_data_dir))

    buffer = BytesIO()
    original.write(buffer)
    restored = Repository.read(buffer)

    for private_attr in ('_Repository__min_shell', '_Repository__max_shell',
                         '_Repository__traceable', '_Repository__versioning'):
        assert getattr(original, private_attr) == getattr(restored, private_attr)

    assert original.charges_iacm == restored.charges_iacm
    assert original.charges_elem == restored.charges_elem
    for repo in (original, restored):
        for iso_attr in ('iso_iacm', 'iso_elem'):
            assert not hasattr(repo, iso_attr)
コード例 #8
0
def test_create_empty():
    """A default Repository has shells 1..7, no flags and empty tables."""
    repo = Repository()

    expected_settings = (
        ('min_shell', 1),
        ('max_shell', 7),
        ('traceable', False),
        ('versioning', False),
    )
    for setting, value in expected_settings:
        assert getattr(repo, '_Repository__' + setting) == value

    for table_name in ('charges_iacm', 'charges_elem'):
        assert len(getattr(repo, table_name)) == 0
    for absent_name in ('iso_iacm', 'iso_elem'):
        assert not hasattr(repo, absent_name)
コード例 #9
0
def test_create_versioning_from_dir(lgf_data_dir):
    """Build a versioning repository from the LGF test data and inspect it.

    Entries are bare charges stored in _VersioningList containers, and
    every in-place modification of such a list must leave it with a
    version distinct from all versions it had before.
    """
    repo = Repository.create_from(str(lgf_data_dir), versioning=True)
    assert repo._Repository__min_shell == 1
    assert repo._Repository__max_shell == 7
    assert repo._Repository__traceable == False
    assert repo._Repository__versioning == True

    iacm_1 = 'c18208da9e290c6faf8a0c58017d24d9'
    iacm_3 = '76198d87470cc1b2f871da60449146bc'
    elem_1 = '92ed00c54b2190be94748bee34b22847'
    elem_3 = '17ac3199bf634022485c145821f358d5'

    assert len(repo.charges_iacm) == 7
    for shell_size, expected_len in ((1, 10), (2, 14), (3, 15)):
        assert len(repo.charges_iacm[shell_size]) == expected_len

    assert isinstance(repo.charges_iacm[1][iacm_1], _VersioningList)
    assert isinstance(repo.charges_iacm[3][iacm_3], _VersioningList)

    assert repo.charges_iacm[1][iacm_1] == [-0.516]
    assert repo.charges_iacm[3][iacm_3] == [0.077, 0.077, 0.077]

    assert len(repo.charges_elem) == 7
    for shell_size, expected_len in ((1, 9), (2, 14), (3, 15)):
        assert len(repo.charges_elem[shell_size]) == expected_len

    assert isinstance(repo.charges_elem[1][elem_1], _VersioningList)
    assert isinstance(repo.charges_elem[3][elem_3], _VersioningList)

    assert repo.charges_elem[1][elem_1] == [-0.516]
    assert repo.charges_elem[3][elem_3] == [0.077, 0.077, 0.077]

    assert not hasattr(repo, 'iso_iacm')
    assert not hasattr(repo, 'iso_elem')

    # Record the version before and after each kind of mutation. The
    # augmented assignment is kept on the full subscript expression so the
    # dict entry itself is reassigned, exactly as in normal use.
    seen_versions = [repo.charges_iacm[1][iacm_1].version]
    repo.charges_iacm[1][iacm_1].append(1)
    seen_versions.append(repo.charges_iacm[1][iacm_1].version)
    repo.charges_iacm[1][iacm_1] += [2]
    seen_versions.append(repo.charges_iacm[1][iacm_1].version)
    del repo.charges_iacm[1][iacm_1][-1]
    seen_versions.append(repo.charges_iacm[1][iacm_1].version)
    repo.charges_iacm[1][iacm_1].remove(1)
    seen_versions.append(repo.charges_iacm[1][iacm_1].version)

    assert len(set(seen_versions)) == 5
コード例 #10
0
def test_remove_from(lgf_data_dir):
    """Removing the data a repository was built from leaves it empty.

    The shell bounds and flags stay at their defaults, and no
    ``iso_iacm``/``iso_elem`` attributes appear.
    """
    repo = Repository.create_from(str(lgf_data_dir))
    repo.remove_from(lgf_data_dir)

    expected_settings = (
        ('min_shell', 1),
        ('max_shell', 7),
        ('traceable', False),
        ('versioning', False),
    )
    for setting, value in expected_settings:
        assert getattr(repo, '_Repository__' + setting) == value

    for table_name in ('charges_iacm', 'charges_elem'):
        assert len(getattr(repo, table_name)) == 0

    for absent_name in ('iso_iacm', 'iso_elem'):
        assert not hasattr(repo, absent_name)
コード例 #11
0
ファイル: charge_server.py プロジェクト: RicoGe/charge_assign
def init(repo_location: Optional[str] = None) -> None:
    """Read a repository and install a CDPCharger for it as the module charger.

    Args:
        repo_location: Path of the repository to read. When it is None or
            empty, the location is taken from parse_arguments(), i.e. from
            the command line.

    """
    global _charger
    location = repo_location or parse_arguments()
    _charger = CDPCharger(Repository.read(location), rounding_digits=3)
コード例 #12
0
def test_add_from_versioning(lgf_data_dir):
    """Adding the same data again to a versioning repository.

    The number of keys per shell stays the same, the sampled charge lists
    contain every charge twice, and those lists are still _VersioningList
    instances.
    """
    repo = Repository.create_from(str(lgf_data_dir), versioning=True)
    repo.add_from(lgf_data_dir)

    expected_settings = (
        ('min_shell', 1),
        ('max_shell', 7),
        ('traceable', False),
        ('versioning', True),
    )
    for setting, value in expected_settings:
        assert getattr(repo, '_Repository__' + setting) == value

    iacm_1 = 'c18208da9e290c6faf8a0c58017d24d9'
    iacm_3 = '76198d87470cc1b2f871da60449146bc'
    elem_1 = '92ed00c54b2190be94748bee34b22847'
    elem_3 = '17ac3199bf634022485c145821f358d5'
    doubled_single = [-0.516] * 2
    doubled_triple = [0.077] * 6

    assert len(repo.charges_iacm) == 7
    for shell_size, expected_len in ((1, 10), (2, 14), (3, 15)):
        assert len(repo.charges_iacm[shell_size]) == expected_len
    assert repo.charges_iacm[1][iacm_1] == doubled_single
    assert repo.charges_iacm[3][iacm_3] == doubled_triple

    assert len(repo.charges_elem) == 7
    for shell_size, expected_len in ((1, 9), (2, 14), (3, 15)):
        assert len(repo.charges_elem[shell_size]) == expected_len
    assert repo.charges_elem[1][elem_1] == doubled_single
    assert repo.charges_elem[3][elem_3] == doubled_triple

    assert not hasattr(repo, 'iso_iacm')
    assert not hasattr(repo, 'iso_elem')

    sampled_entries = (
        (repo.charges_iacm, 1, iacm_1),
        (repo.charges_iacm, 3, iacm_3),
        (repo.charges_elem, 1, elem_1),
        (repo.charges_elem, 3, elem_3),
    )
    for table, shell_size, key in sampled_entries:
        assert isinstance(table[shell_size][key], _VersioningList)
コード例 #13
0
ファイル: validation.py プロジェクト: RicoGe/charge_assign
def cross_validate_methods(
        data_location: str,
        data_type: IOType = IOType.LGF,
        min_shell: Optional[int] = None,
        max_shell: Optional[int] = None
        ) -> Tuple[Dict[str, Dict[bool, float]], Dict[str, Dict[bool, float]]]:
    """Cross-validates all methods on the given molecule data.

    Builds one traceable repository from the data, then runs
    cross_validate_molecules for every charger class, both with and
    without IACM atom types.

    Args:
        data_location: Path to the directory with the molecule data.
        data_type: Format of the molecule data to expect.
        min_shell: Smallest shell size to use. If None, the repository's \
                default is used.
        max_shell: Largest shell size to use. If None, the repository's \
                default is used.

    Returns:
        Dictionaries keyed by charger name and whether IACM atoms were \
                used, containing the mean absolute error and the mean \
                square error respectively.
    """
    # Forward only the bounds that were actually given, so that
    # Repository.create_from falls back to its own defaults for the rest
    # instead of receiving None.
    shell_bounds = dict()
    if min_shell is not None:
        shell_bounds['min_shell'] = min_shell
    if max_shell is not None:
        shell_bounds['max_shell'] = max_shell
    repo = Repository.create_from(data_location, data_type, traceable=True,
                                  **shell_bounds)

    # Largest shell first. Without both bounds, pass None so that
    # cross_validate_molecules uses every shell present in the repository.
    if min_shell is not None and max_shell is not None:
        shell = range(max_shell, min_shell - 1, -1)
    else:
        shell = None

    # NOTE(review): SymmetricILPCharger used to be listed twice; if a
    # symmetric DP variant was intended, add it here.
    charger_types = [MeanCharger, MedianCharger, ModeCharger, ILPCharger,
                     DPCharger, CDPCharger, SymmetricILPCharger,
                     SymmetricCDPCharger]

    mean_abs_err = dict()
    mean_sq_err = dict()

    for charger_type in charger_types:
        charger_name = charger_type.__name__
        mean_abs_err[charger_name] = dict()
        mean_sq_err[charger_name] = dict()
        for iacm in [True, False]:
            # NOTE(review): this unpacks the result as a (MAE, MSE) pair;
            # cross_validate_molecules is annotated to return a
            # ValidationReport, so confirm that this still holds.
            mae, mse = cross_validate_molecules(
                    charger_name, iacm, data_location, data_type, shell, repo)
            mean_abs_err[charger_name][iacm] = mae
            mean_sq_err[charger_name][iacm] = mse

    return mean_abs_err, mean_sq_err
コード例 #14
0

if __name__ == '__main__':
    # Runs one bucket of the cross-validation against the pre-built
    # repository file, so the work can be split over num_buckets processes.
    test_data_dir = os.path.realpath(
        os.path.join(__file__, '..', 'cross_validation_data'))

    usage = 'Usage: {} bucket num_buckets'.format(sys.argv[0])

    if len(sys.argv) != 3:
        # quit() is an interactive-interpreter helper installed by the site
        # module; scripts should use sys.exit(). A usage error must also
        # produce a non-zero exit status.
        print(usage, file=sys.stderr)
        print('Got {} arguments.'.format(len(sys.argv) - 1), file=sys.stderr)
        sys.exit(1)

    try:
        bucket = int(sys.argv[1])
        num_buckets = int(sys.argv[2])
    except ValueError:
        print(usage, file=sys.stderr)
        print('bucket and num_buckets must be integers.', file=sys.stderr)
        sys.exit(1)

    repo_file = 'cross_validation_repository.zip'
    repo = Repository.read(repo_file)

    # (charger name, iacm flag) pairs, run in exactly this order.
    runs = [
        ('MeanCharger', False), ('MeanCharger', True),
        ('MedianCharger', False), ('MedianCharger', True),
        ('ModeCharger', False), ('ModeCharger', True),
        ('ILPCharger', False),
    ]
    for charger_name, iacm in runs:
        cross_validate(charger_name, iacm, 3, test_data_dir, repo, bucket,
                       num_buckets)
コード例 #15
0
import os

from charge.repository import Repository

if __name__ == '__main__':
    # Build a traceable repository with shells 0..3 from the
    # cross_validation_data directory next to this script and save it as
    # cross_validation_repository.zip.
    data_dir = os.path.realpath(
        os.path.join(__file__, '..', 'cross_validation_data'))
    repository_file = 'cross_validation_repository.zip'

    Repository.create_from(
        data_dir,
        min_shell=0,
        max_shell=3,
        traceable=True,
    ).write(repository_file)
コード例 #16
0
ファイル: validation.py プロジェクト: RicoGe/charge_assign
def cross_validate_molecules(
        charger_type: str,
        iacm: bool,
        data_location: str,
        data_type: IOType = IOType.LGF,
        shell: Union[None, int, Iterable[int]] = None,
        repo: Optional[Repository] = None,
        bucket: int = 0,
        num_buckets: int = 1
        ) -> ValidationReport:
    """Cross-validates a particular method on the given molecule data.

    Runs through all the molecules in the repository, and for each, \
    predicts charges using the given charger type and from the rest of \
    the molecules in the repository.

    If iacm is False, the test molecule is stripped of its charge data \
    and its IACM atom types, leaving only plain elements. It is then \
    matched first against the IACM side of the repository, and if that \
    yields no charges, against the plain element side of the repository.

    If iacm is True, the test molecule is stripped of charges but keeps \
    its IACM atom types. It is then matched against the IACM side of the \
    repository, and if no matches are found, its plain elements are \
    matched against the plain element side.

    If bucket and num_buckets are specified, then this will only run \
    the cross-validation if (molid % num_buckets) == bucket. Molecules \
    are processed in ascending molid order.

    Args:
        charger_type: Name of a Charger-derived class implementing an \
                assignment method.
        iacm: Whether to use IACM or plain element atoms.
        data_location: Path to the directory with the molecule data.
        data_type: Format of the molecule data to expect.
        shell: (List of) shell size(s) to use. Any iterable is accepted \
                and consumed exactly once. None means every shell in the \
                repository.
        repo: A Repository with traceable charges. Built from \
                data_location if not given.
        bucket: Cross-validate for this bucket; 0 <= bucket < num_buckets.
        num_buckets: Total number of buckets that will run; at least 1.

    Returns:
        The sum of the reports returned by cross_validate_molecule for \
        every molecule in the selected bucket. (Earlier documentation \
        described this as AtomReports per element category plus a \
        MoleculeReport keyed 'Molecule' - confirm against ValidationReport.)

    Raises:
        ValueError: If num_buckets < 1, bucket is outside \
                [0, num_buckets), or shell is an empty iterable.
    """
    if num_buckets < 1:
        raise ValueError(
                'num_buckets must be at least 1, got {}'.format(num_buckets))
    if not 0 <= bucket < num_buckets:
        # molid % num_buckets always lies in [0, num_buckets), so any other
        # bucket would silently validate nothing.
        raise ValueError('bucket must be in [0, {}), got {}'.format(
                num_buckets, bucket))

    if shell is None:
        min_shell, max_shell = None, None
        wanted_shells = None
    else:
        if isinstance(shell, int):
            shell = [shell]
        # Materialise once: a generator would be exhausted by min() before
        # max() and sorted() ever saw it.
        wanted_shells = sorted(shell, reverse=True)
        if not wanted_shells:
            raise ValueError('shell must contain at least one shell size')
        max_shell, min_shell = wanted_shells[0], wanted_shells[-1]

    if repo is None:
        if min_shell is not None:
            repo = Repository.create_from(data_location, data_type, min_shell,
                    max_shell, traceable=True)
        else:
            repo = Repository.create_from(data_location, data_type,
                    traceable=True)

    if wanted_shells is None:
        shells = sorted(repo.charges_iacm.keys(), reverse=True)
    else:
        shells = []
        for s in wanted_shells:
            if s in repo.charges_iacm.keys():
                shells.append(s)
            else:
                msg = 'Shell {} will not be used, as it is not in the repository'
                warn(msg.format(s))

    nauty = Nauty()

    extension = data_type.get_extension()
    # Strip only the trailing extension and keep the real file name next to
    # the parsed id, so names that do not round-trip through int() (e.g.
    # '007.lgf') can still be opened. Sorting makes the order deterministic,
    # since os.listdir returns entries in arbitrary order.
    molecule_files = sorted(
            (int(fn[:len(fn) - len(extension)]), fn)
            for fn in os.listdir(data_location)
            if fn.endswith(extension))

    report = ValidationReport()

    for molid, file_name in molecule_files:
        if molid % num_buckets != bucket:
            continue

        mol_path = os.path.join(data_location, file_name)
        with open(mol_path, 'r') as f:
            graph = convert_from(f.read(), data_type)

        report += cross_validate_molecule(
                repo, molid, graph, charger_type, shells, iacm, nauty)

    return report
コード例 #17
0
def test_remove_from_traceable(lgf_data_dir):
    """remove_from() on a traceable repository must raise ValueError."""
    traceable_repo = Repository.create_from(str(lgf_data_dir), traceable=True)
    expect_value_error = pytest.raises(ValueError)
    with expect_value_error:
        traceable_repo.remove_from(lgf_data_dir)
コード例 #18
0
def test_set_shell_sizes():
    """Repository(2, 4) stores 2 as its minimum and 4 as its maximum shell size."""
    repo = Repository(2, 4)
    assert repo._Repository__min_shell == 2
    assert repo._Repository__max_shell == 4