def make_grid_evaluators(features,
                         weights,
                         targets,
                         e_lda_x,
                         e_lda_css,
                         e_lda_cos,
                         signature,
                         train_validation_test_split=(0.6, 0.2, 0.2),
                         eval_modes=('jit', 'onp', 'onp'),
                         seed=0):
    """Constructs train, validation and test grid evaluators.

    Grid points are shuffled with a seeded permutation and split into three
    partitions according to `train_validation_test_split`. One
    `evaluators.GridEvaluator` is built per partition from the corresponding
    rows of every per-grid input.

    Args:
      features: Dict mapping feature names to array-likes. Each is indexed by
        grid point along its first axis (presumably with the same leading
        dimension as `weights` -- shape checks are left to GridEvaluator).
      weights: Array-like of per-grid weights. Its length defines the number
        of grid points that are partitioned.
      targets: Array-like indexed by grid point along its first axis.
      e_lda_x: Array-like indexed by grid point along its first axis.
      e_lda_css: Array-like indexed by grid point along its first axis.
      e_lda_cos: Array-like indexed by grid point along its first axis.
      signature: Passed unchanged to every GridEvaluator.
      train_validation_test_split: Sequence of three non-negative fractions
        summing to 1 (within 1e-8), giving the train/validation/test sizes.
      eval_modes: Sequence of three eval modes, used for the train,
        validation and test evaluators respectively.
      seed: Seed for the permutation that shuffles grid points before
        splitting. The default of 0 reproduces the historical partition.

    Returns:
      A list of three GridEvaluators, in [train, validation, test] order.

    Raises:
      ValueError: If `train_validation_test_split` or `eval_modes` is invalid.
    """
    partition_names = ('train', 'validation', 'test')

    if (len(train_validation_test_split) != len(partition_names)
            or any(frac < 0. for frac in train_validation_test_split)
            or abs(sum(train_validation_test_split) - 1.) > 1e-8):
        raise ValueError(
            f'Invalid train_validation_test_split: '
            f'{train_validation_test_split}')
    # Validate up front: a short eval_modes used to fail with an IndexError
    # only after some evaluators had already been constructed.
    if len(eval_modes) != len(partition_names):
        raise ValueError(
            f'Invalid eval_modes: expected {len(partition_names)} values '
            f'(train, validation, test), got {eval_modes}')

    features = {
        feature_name: np.array(feature)
        for feature_name, feature in features.items()
    }
    weights = np.array(weights)
    targets = np.array(targets)
    e_lda_x = np.array(e_lda_x)
    e_lda_css = np.array(e_lda_css)
    e_lda_cos = np.array(e_lda_cos)

    num_grids = len(weights)
    # Cut points of the shuffled index array: [0, train_end) is train,
    # [train_end, validation_end) is validation, the rest is test.
    train_end = int(train_validation_test_split[0] * num_grids)
    validation_end = int(sum(train_validation_test_split[:2]) * num_grids)
    grid_indices_partition = np.split(
        np.random.RandomState(seed).permutation(num_grids),
        [train_end, validation_end])

    evaluator_list = []
    for partition, grid_indices, eval_mode in zip(
            partition_names, grid_indices_partition, eval_modes):
        evaluator = evaluators.GridEvaluator(
            # make copies to ensure memory layout is contiguous
            features={
                feature_name: feature[grid_indices].copy()
                for feature_name, feature in features.items()
            },
            weights=weights[grid_indices].copy(),
            targets=targets[grid_indices].copy(),
            e_lda_x=e_lda_x[grid_indices].copy(),
            e_lda_css=e_lda_css[grid_indices].copy(),
            e_lda_cos=e_lda_cos[grid_indices].copy(),
            signature=signature,
            eval_mode=eval_mode)
        logging.info('GridEvaluator on %s set constructed: %s', partition,
                     evaluator)
        evaluator_list.append(evaluator)

    return evaluator_list
 def test_initialize_grid_evaluator_with_wrong_e_lda_shape(self):
     """Checks GridEvaluator raises ValueError for a mis-shaped e_lda_css."""
     num_grids = 10
     grid_kwargs = dict(
         features={'u': np.zeros(num_grids)},
         weights=np.zeros(num_grids),
         targets=np.zeros(num_grids),
         e_lda_x=np.zeros(num_grids),
         # Deliberately twice as long as every other per-grid array.
         e_lda_css=np.zeros(2 * num_grids),
         e_lda_cos=np.zeros(num_grids),
         signature='f_xc',
     )
     expected_pattern = r'Wrong shape for e_lda. Expected \(10,\), got \(20,\)'
     with self.assertRaisesRegex(ValueError, expected_pattern):
         evaluators.GridEvaluator(**grid_kwargs)
 def test_initialize_grid_evaluator_with_wrong_signature(self):
     """Checks GridEvaluator raises ValueError for an unsupported signature."""
     # Every per-grid array has the same (valid) length, so only the
     # signature can trigger the error.
     per_grid_zeros = {
         arg_name: np.zeros(10)
         for arg_name in ('weights', 'targets', 'e_lda_x', 'e_lda_css',
                          'e_lda_cos')
     }
     with self.assertRaisesRegex(
             ValueError,
             'Unknown signature xc, supported values are e_xc and f_xc'):
         evaluators.GridEvaluator(features={},
                                  signature='xc',
                                  eval_mode='onp',
                                  **per_grid_zeros)
    def test_eval_wrmsd_e_xc(self, eval_mode):
        """Checks WRMSD of the empty and LDA functionals against unit targets.

        Weights are rescaled so that sum(weights**2) == num_grids. The test
        asserts a WRMSD of 1 for `empty_functional` and 0 for
        `lda_functional`. `eval_mode` is presumably supplied by a
        parameterization decorator outside this view.
        """
        num_grids = 10
        raw_weights = np.random.rand(num_grids)
        # Normalize so the squared weights sum to num_grids.
        normalization = np.sqrt(num_grids / np.sum(raw_weights**2))
        # Drawn in the same order as the keyword arguments below, so the
        # global RNG is consumed exactly as before.
        lda_x = np.random.rand(num_grids)
        lda_css = np.random.rand(num_grids)
        lda_cos = np.random.rand(num_grids)

        evaluator = evaluators.GridEvaluator(
            weights=raw_weights * normalization,
            # targets are LDA enhancement factors 1.
            targets=np.ones(num_grids),
            e_lda_x=lda_x,
            e_lda_css=lda_css,
            e_lda_cos=lda_cos,
            features={},
            signature='f_xc',
            eval_mode=eval_mode)

        expectations = (
            (xc_functionals.empty_functional, 1.),
            (xc_functionals.lda_functional, 0.),
        )
        for functional, expected_wrmsd in expectations:
            eval_wrmsd = evaluator.get_eval_wrmsd(functional)
            self.assertAlmostEqual(eval_wrmsd({}, {}, {}), expected_wrmsd)