Пример #1
0
def run_scaling(params, experiments, reflections):
    """Run scaling: stats only, MTZ export only, cross validation or standard.

    Depending on ``params`` this does one of the following:

    * reports statistics only, then exits the process;
    * exports MTZ only, then exits the process;
    * runs a cross-validation analysis, which writes no scaling output files;
    * runs scaling (optionally combined with filtering) and exports the result.

    Args:
        params: The scaling phil parameters.
        experiments: The experiments to scale.
        reflections: An iterable of reflection tables.

    Raises:
        Sorry: If cross validation raises a ValueError, or if filtering is
            requested while the scaler's ``id_`` is not ``"multi"``.
    """
    if params.stats_only:
        Script.stats_only(reflections, experiments, params)
        sys.exit()

    if params.export_mtz_only:
        Script.export_mtz_only(reflections, experiments, params)
        sys.exit()

    if params.output.delete_integration_shoeboxes:
        for r in reflections:
            # A table without a shoebox column already satisfies the request;
            # an unconditional ``del`` would abort the run with a KeyError.
            if "shoebox" in r:
                del r["shoebox"]

    if params.cross_validation.cross_validation_mode:
        # Imported lazily: only needed in cross-validation mode.
        from dials.algorithms.scaling.cross_validation.cross_validate import (
            cross_validate,
        )
        from dials.algorithms.scaling.cross_validation.crossvalidator import (
            DialsScaleCrossValidator,
        )

        cross_validator = DialsScaleCrossValidator(experiments, reflections)
        try:
            cross_validate(params, cross_validator)
        except ValueError as e:
            # Surface bad cross-validation configuration as a user-facing error.
            raise Sorry(e)

        logger.info(
            "Cross validation analysis does not produce scaling output files, rather\n"
            "it gives insight into the dataset. Choose an appropriate parameterisation\n"
            "and rerun scaling without cross_validation_mode.\n"
        )

    else:
        script = Script(params, experiments, reflections)
        # Register the observers at the highest level
        if params.output.html:
            register_default_scaling_observers(script)
        else:
            register_merging_stats_observers(script)
        if params.filtering.method:
            # Scale-and-filter is only supported for the multi-dataset scaler.
            if script.scaler.id_ != "multi":
                raise Sorry(
                    """
Scaling and filtering can only be performed in multi-dataset scaling mode
(not single dataset or scaling against a reference)"""
                )
            register_scale_and_filter_observers(script)
            script.run_scale_and_filter()
            with open(params.filtering.output.scale_and_filter_results, "w") as f:
                json.dump(script.filtering_results.to_dict(), f, indent=2)
        else:
            script.run()
        script.export()
Пример #2
0
def run_scaling(params, experiments, reflections):
    """Run scaling algorithms; cross validation, scaling + filtering or standard.

    Args:
        params: The scaling phil parameters.
        experiments: The experiments to scale.
        reflections: An iterable of reflection tables.

    Returns:
        experiments: an experiment list with scaled data (if created)
        joint_table: a single reflection table containing scaled data (if created).
        Cross-validation mode creates no scaled data and returns (None, None).
    """
    if params.output.delete_integration_shoeboxes:
        for r in reflections:
            # A table without a shoebox column already satisfies the request;
            # an unconditional ``del`` would abort the run with a KeyError.
            if "shoebox" in r:
                del r["shoebox"]

    if params.cross_validation.cross_validation_mode:
        # Imported lazily: only needed in cross-validation mode.
        from dials.algorithms.scaling.cross_validation.cross_validate import (
            cross_validate,
        )
        from dials.algorithms.scaling.cross_validation.crossvalidator import (
            DialsScaleCrossValidator,
        )

        cross_validator = DialsScaleCrossValidator(experiments, reflections)
        cross_validate(params, cross_validator)

        logger.info(
            "Cross validation analysis does not produce scaling output files, rather\n"
            "it gives insight into the dataset. Choose an appropriate parameterisation\n"
            "and rerun scaling without cross_validation_mode.\n"
        )
        return (None, None)

    # Pick the algorithm; filtering gets its own observers on top of the
    # common ones registered below.
    if params.filtering.method:
        algorithm = ScaleAndFilterAlgorithm(params, experiments, reflections)
        register_scale_and_filter_observers(algorithm)
    else:
        algorithm = ScalingAlgorithm(params, experiments, reflections)

    # Register the observers at the highest level
    if params.output.html:
        register_default_scaling_observers(algorithm)
    else:
        register_merging_stats_observers(algorithm)

    algorithm.run()

    experiments, joint_table = algorithm.finish()
    return experiments, joint_table
Пример #3
0
def test_cross_validate_script():
    """Check how cross_validate dispatches in single and multi mode.

    DialsScaleCrossValidator.run_script and .interpret_results are replaced
    by mocks, so only the call counts and the ValueError cases of
    cross_validate itself are exercised.
    """
    params = generated_param()
    validator = DialsScaleCrossValidator([], [])

    module_path = "dials.algorithms.scaling.cross_validation."
    run_script_target = (
        module_path + "crossvalidator.DialsScaleCrossValidator.run_script"
    )
    interpret_target = (
        module_path + "crossvalidator.DialsScaleCrossValidator.interpret_results"
    )

    # multi mode with no parameter chosen must be rejected
    params.cross_validation.cross_validation_mode = "multi"
    with pytest.raises(ValueError):
        cross_validate(params, validator)

    # single mode: one run_script call per fold, one interpretation
    params.cross_validation.cross_validation_mode = "single"
    params.cross_validation.nfolds = 2
    with mock.patch(run_script_target) as run_mock, mock.patch(
        interpret_target
    ) as interpret_mock:
        cross_validate(params, validator)
        assert run_mock.call_count == 2
        assert interpret_mock.call_count == 1

    # multi mode, starting again from fresh parameters
    params = generated_param()
    params.cross_validation.cross_validation_mode = "multi"
    params.cross_validation.nfolds = 2
    with mock.patch(run_script_target) as run_mock, mock.patch(
        interpret_target
    ) as interpret_mock:
        params.cross_validation.parameter = "physical.absorption_correction"
        cross_validate(params, validator)
        assert run_mock.call_count == 4
        assert interpret_mock.call_count == 1

        # decay_interval without parameter_values is expected to raise
        params.cross_validation.parameter = "physical.decay_interval"
        with pytest.raises(ValueError):
            cross_validate(params, validator)

        # Each case adds four run_script calls (presumably two values x two
        # folds) and one interpretation; the counters are cumulative.
        explicit_cases = [
            ("physical.absorption_correction", ["True", "False"]),
            ("physical.decay_interval", ["5.0", "10.0"]),
            ("model", ["array", "physical"]),
            ("physical.lmax", ["4", "6"]),
        ]
        for completed, (name, values) in enumerate(explicit_cases, start=2):
            params.cross_validation.parameter = name
            params.cross_validation.parameter_values = values
            cross_validate(params, validator)
            assert run_mock.call_count == 4 * completed
            assert interpret_mock.call_count == completed

        params.cross_validation.parameter = "bad_interval"
        with pytest.raises(ValueError):
            cross_validate(params, validator)

        params.cross_validation.cross_validation_mode = "bad"
        with pytest.raises(ValueError):
            cross_validate(params, validator)