Ejemplo n.º 1
0
def test_template_classes(traj):
    """Check that ``representation='f'`` table output has no exponent notation.

    Builds a ``Table`` of the 'dx', 'dy', 'dz' fields between the first two
    images read from the ``traj`` fixture. It then inspects the line two lines
    below the first line containing the ``'|'`` midrule marker.

    Asserts that this row contains no ``'E'`` and that its first parsed
    field is at least ``prec`` characters long.
    """
    prec = 4
    tableformat = TableFormat(precision=prec, representation='f', midrule='|')
    table = Table(field_specs=('dx', 'dy', 'dz'), tableformat=tableformat)
    traj = read(str(traj), ':')
    table_out = table.make(traj[0], traj[1]).split('\n')

    # Locate the first line carrying the midrule marker. If it is missing,
    # fail with a readable message instead of an IndexError further down.
    for counter, line in enumerate(table_out):
        if '|' in line:
            break
    else:
        raise AssertionError(
            "midrule marker '|' not found in table output:\n"
            + '\n'.join(table_out))

    # The row under test is two lines below the midrule line (presumably the
    # first data row -- depends on Table's layout).
    row = table_out[counter + 2]
    assert 'E' not in row

    # Collapse whitespace runs into commas and drop the first and last
    # fields (presumably empty, produced by leading/trailing whitespace).
    fields = re.sub(r'\s+', ',', row).split(',')[1:-1]
    assert len(fields[0]) >= prec
Ejemplo n.º 2
0
    def diff(args, out):
        """Compare atomic configurations and print the result as tables.

        Reads images from ``args.file[0]`` (and ``args.file[1]`` if given)
        and prints one table per image pair to ``out``:

        * one file: consecutive images are paired (image c with image c+1);
        * second file is ``'-'``: each image of the first file is paired
          with itself;
        * two files of equal length: image c of the first file is paired
          with image c of the second;
        * two files where one holds a single image: that image is repeated
          so every image of the other file is compared against it.

        Args:
            args: parsed command-line namespace; reads ``template``,
                ``calculator_outputs``, ``rank_order``, ``summary_functions``,
                ``file``, ``precision``, ``max_lines`` and ``as_csv``.
            out: text stream the tables are printed to via
                ``print(..., file=out)``.

        Raises:
            CLIError: if a template field or summary function needing
                calculator outputs is requested without
                ``--calculator-outputs``; if an unknown summary function is
                named; or if two files differ in image count and both hold
                more than one image.
        """
        from ase.cli.template import (
            Table,
            TableFormat,
            slice_split,
            field_specs_on_conditions,
            summary_functions_on_conditions,
            rmsd,
            energy_delta)

        # --- Resolve table columns -----------------------------------------
        if args.template is None:
            field_specs = field_specs_on_conditions(
                args.calculator_outputs, args.rank_order)
        else:
            field_specs = args.template.split(',')
            # Presumably any field spec containing 'f' needs calculator
            # results (e.g. forces) -- this mirrors the original check.
            if not args.calculator_outputs and any(
                    'f' in field_spec for field_spec in field_specs):
                raise CLIError(
                    "field requiring calculation outputs "
                    "without --calculator-outputs")

        # --- Resolve summary functions -------------------------------------
        if args.summary_functions is None:
            summary_functions = summary_functions_on_conditions(
                args.calculator_outputs)
        else:
            summary_functions_dct = {
                'rmsd': rmsd,
                'dE': energy_delta}
            summary_function_names = args.summary_functions.split(',')
            if not args.calculator_outputs and 'dE' in summary_function_names:
                raise CLIError(
                    "summary function requiring calculation outputs "
                    "without --calculator-outputs")
            # Report unknown names as a CLI error instead of a raw KeyError.
            unknown = [name for name in summary_function_names
                       if name not in summary_functions_dct]
            if unknown:
                raise CLIError(
                    "unknown summary function(s) {}; choose from {}".format(
                        ', '.join(unknown),
                        ', '.join(sorted(summary_functions_dct))))
            summary_functions = [summary_functions_dct[name]
                                 for name in summary_function_names]

        # --- Load images and build the pairs to compare --------------------
        have_two_files = len(args.file) == 2
        file1 = args.file[0]
        actual_filename, index = slice_split(file1)
        atoms1 = read(actual_filename, index)
        # Number of images read (assumes read() returns a list of images for
        # the index produced by slice_split -- TODO confirm for int indices).
        natoms1 = len(atoms1)

        if have_two_files:
            if args.file[1] == '-':
                atoms2 = atoms1

                def header_fmt(c):
                    return 'image # {}'.format(c)
            else:
                file2 = args.file[1]
                actual_filename, index = slice_split(file2)
                atoms2 = read(actual_filename, index)
                natoms2 = len(atoms2)

                same_length = natoms1 == natoms2
                one_l_one = natoms1 == 1 or natoms2 == 1

                if not same_length and not one_l_one:
                    raise CLIError(
                        "Trajectory files are not the same length "
                        "and both > 1\n{}!={}".format(
                            natoms1, natoms2))
                elif not same_length and one_l_one:
                    print(
                        "One file contains one image "
                        "and the other multiple images,\n"
                        "assuming you want to compare all images "
                        "with one reference image")
                    # Repeat the single reference image so both sequences
                    # have the same length, and keep the counts in sync.
                    # Previously natoms1 stayed 1 when the first file was
                    # the single-image one, so only one pair was printed.
                    if natoms1 > natoms2:
                        atoms2 = natoms1 * atoms2
                        natoms2 = natoms1
                    else:
                        atoms1 = natoms2 * atoms1
                        natoms1 = natoms2

                    def header_fmt(c):
                        return 'sys-ref image # {}'.format(c)
                else:
                    def header_fmt(c):
                        return 'sys2-sys1 image # {}'.format(c)
        else:
            # Single file: pair image c with image c + 1.
            atoms2 = atoms1.copy()
            atoms1 = atoms1[:-1]
            atoms2 = atoms2[1:]
            natoms2 = natoms1 = natoms1 - 1

            def header_fmt(c):
                return 'images {}-{}'.format(c + 1, c)

        natoms = natoms1  # number of image pairs to compare

        # --- Render --------------------------------------------------------
        tableformat = TableFormat(precision=args.precision,
                                  columnwidth=7 + args.precision)

        table = Table(
            field_specs,
            max_lines=args.max_lines,
            tableformat=tableformat,
            summary_functions=summary_functions)

        # Collect chunks and join once rather than growing a string in the
        # loop; the printed text is identical.
        chunks = []
        for counter in range(natoms):
            table.title = header_fmt(counter)
            chunks.append(table.make(atoms1[counter], atoms2[counter],
                                     csv=args.as_csv) + '\n')
        print(''.join(chunks), file=out)