コード例 #1
0
ファイル: test_diff.py プロジェクト: arosen93/rASE
def test_template_functions():
    """Exercise the helper utilities of the template module.

    Covers three things:

    * ``prec_round`` called with precisions 1 through 5 on the same value,
    * ``slice_split`` on a ``name@start:stop:step`` string,
    * ``MapFormatter`` with the ``h`` format spec, which is expected to turn
      the number looked up in ``sym2num`` back into the original symbol.
    """
    value = 1.55749
    rounded = []
    for precision in range(1, 6):
        rounded.append(prec_round(value, precision))
    assert rounded == [1.6, 1.56, 1.557, 1.5575, 1.55749]

    # The text before '@' is the name, the text after it becomes a slice.
    parsed = slice_split('a@1:3:1')
    assert parsed == ('a', slice(1, 3, 1))

    # Round trip: symbol -> number -> symbol via the 'h' format spec.
    symbol = 'H'
    number = sym2num[symbol]
    formatter = MapFormatter()
    assert symbol == formatter.format('{:h}', number)
コード例 #2
0
def test_template_functions():
    """Exercise the helper utilities of the template module.

    Covers four things:

    * ``prec_round`` called with precisions 1 through 5 on the same value,
    * ``sort2rank`` on a small permutation,
    * ``slice_split`` on a ``name@start:stop:step`` string,
    * ``MapFormatter`` with the ``h`` format spec, which is expected to turn
      the number looked up in ``sym2num`` back into the original symbol.
    """
    value = 1.55749
    rounded = []
    for precision in range(1, 6):
        rounded.append(prec_round(value, precision))
    assert rounded == [1.6, 1.56, 1.557, 1.5575, 1.55749]

    # sort2rank outputs a numpy array, so ``==`` yields an element-wise
    # boolean array that has to be reduced with ``.all()``.
    ranks = sort2rank([3, 2, 4, 1, 0])
    elementwise_match = [4, 3, 1, 0, 2] == ranks
    assert elementwise_match.all()

    # The text before '@' is the name, the text after it becomes a slice.
    parsed = slice_split('a@1:3:1')
    assert parsed == ('a', slice(1, 3, 1))

    # Round trip: symbol -> number -> symbol via the 'h' format spec.
    symbol = 'H'
    number = sym2num[symbol]
    formatter = MapFormatter()
    assert symbol == formatter.format('{:h}', number)
コード例 #3
0
ファイル: diff.py プロジェクト: maurergroup/ase_local
    def diff(args, out):
        """Print comparison tables for the images of one or two files.

        Three modes are supported, selected by ``args.file``:

        * one file: every image is compared with the image before it;
        * two files where the second entry is ``'-'``: the images of the
          first file are used on both sides of the comparison;
        * two files: images are compared pairwise.  If exactly one of the
          files holds a single image, that image is repeated so that it is
          compared against every image of the other file.

        Parameters
        ----------
        args
            Parsed command-line namespace.  The attributes read here are
            ``template``, ``calculator_outputs``, ``rank_order``,
            ``summary_functions``, ``file``, ``precision``, ``max_lines``
            and ``as_csv``.
        out
            Object handed to ``print(..., file=out)``; receives all tables.

        Raises
        ------
        CLIError
            If a template field containing ``'f'`` or the ``dE`` summary
            function is requested without ``--calculator-outputs``, or if
            the two files hold different numbers of images and both hold
            more than one.
        KeyError
            If ``args.summary_functions`` names a function other than
            ``rmsd`` or ``dE``.
        """
        from ase.cli.template import (
            Table,
            TableFormat,
            slice_split,
            field_specs_on_conditions,
            summary_functions_on_conditions,
            rmsd,
            energy_delta)

        # --- which columns to show -------------------------------------
        if args.template is None:
            field_specs = field_specs_on_conditions(
                args.calculator_outputs, args.rank_order)
        else:
            field_specs = args.template.split(',')
            # Any field spec containing 'f' is treated as needing
            # calculator outputs.
            if not args.calculator_outputs and any(
                    'f' in field_spec for field_spec in field_specs):
                raise CLIError(
                    "field requiring calculation outputs "
                    "without --calculator-outputs")

        # --- which summary lines to show -------------------------------
        if args.summary_functions is None:
            summary_functions = summary_functions_on_conditions(
                args.calculator_outputs)
        else:
            summary_functions_dct = {
                'rmsd': rmsd,
                'dE': energy_delta}
            summary_names = args.summary_functions.split(',')
            if not args.calculator_outputs and 'dE' in summary_names:
                raise CLIError(
                    "summary function requiring calculation outputs "
                    "without --calculator-outputs")
            summary_functions = [summary_functions_dct[name]
                                 for name in summary_names]

        # --- load the images -------------------------------------------
        # NOTE(review): the code below assumes ``read`` returns a sequence
        # of images (len() == number of images) -- confirm for plain
        # integer indices.
        have_two_files = len(args.file) == 2
        file1 = args.file[0]
        actual_filename, index = slice_split(file1)
        atoms1 = read(actual_filename, index)
        natoms1 = len(atoms1)

        if have_two_files:
            if args.file[1] == '-':
                atoms2 = atoms1

                def header_fmt(c):
                    return 'image # {}'.format(c)
            else:
                file2 = args.file[1]
                actual_filename, index = slice_split(file2)
                atoms2 = read(actual_filename, index)
                natoms2 = len(atoms2)

                same_length = natoms1 == natoms2
                one_l_one = natoms1 == 1 or natoms2 == 1

                if not same_length and not one_l_one:
                    raise CLIError(
                        "Trajectory files are not the same length "
                        "and both > 1\n{}!={}".format(
                            natoms1, natoms2))
                elif not same_length and one_l_one:
                    print(
                        "One file contains one image "
                        "and the other multiple images,\n"
                        "assuming you want to compare all images "
                        "with one reference image")
                    # Repeat the single reference image so both sides have
                    # one entry per image of the longer file.
                    if natoms1 > natoms2:
                        atoms2 = natoms1 * atoms2
                    else:
                        atoms1 = natoms2 * atoms1
                        # The loop below runs over natoms1; without this
                        # update only the first image would be compared
                        # when the reference is the first file.
                        natoms1 = natoms2

                    def header_fmt(c):
                        return 'sys-ref image # {}'.format(c)
                else:
                    def header_fmt(c):
                        return 'sys2-sys1 image # {}'.format(c)
        else:
            # Single file: pair image c (atoms1) with image c + 1 (atoms2).
            atoms2 = atoms1.copy()
            atoms1 = atoms1[:-1]
            atoms2 = atoms2[1:]
            natoms1 = natoms1 - 1

            def header_fmt(c):
                return 'images {}-{}'.format(c + 1, c)

        # Both image lists have this many entries at this point.
        natoms = natoms1

        # --- render ----------------------------------------------------
        tableformat = TableFormat(precision=args.precision,
                                  columnwidth=7 + args.precision)

        table = Table(
            field_specs,
            max_lines=args.max_lines,
            tableformat=tableformat,
            summary_functions=summary_functions)

        rendered = []
        for counter in range(natoms):
            # The title has to be set before each call to ``make``.
            table.title = header_fmt(counter)
            rendered.append(
                table.make(atoms1[counter],
                           atoms2[counter], csv=args.as_csv) + '\n')
        print(''.join(rendered), file=out)