def test_randomize_instruction_list(self):
        """Tests mutate() when the pool only has 'randomize_instruction_list'.

        get_random_instruction is patched to return a fixed sequence of
        instructions. The test checks that the input enhancement factor is
        left unchanged, that the reported mutation type is
        'randomize_instruction_list', and that the new instruction list
        equals the patched sequence.
        """
        # Arguments shared by the two "accumulate the linear term" additions.
        accumulate_linear_term = ('enhancement_factor', 'enhancement_factor',
                                  'linear_term')
        scripted_instructions = [
            instructions.UTransformInstruction('u', 'x2'),
            instructions.AdditionInstruction('enhancement_factor', 'c00',
                                             'enhancement_factor'),
            instructions.MultiplicationInstruction('linear_term', 'c10', 'w'),
            instructions.AdditionInstruction(*accumulate_linear_term),
            instructions.MultiplicationInstruction('linear_term', 'c01', 'u'),
            instructions.AdditionInstruction(*accumulate_linear_term),
        ]
        scripted_random_instruction = mock.patch.object(
            mutators.EnhancementFactorMutator,
            'get_random_instruction',
            side_effect=scripted_instructions)

        with scripted_random_instruction:
            mutator = mutators.EnhancementFactorMutator(
                mutation_pool={'randomize_instruction_list': 1.})
            original = copy.deepcopy(enhancement_factors.f_x_wb97mv_short)
            mutated, mutation_type, _, _ = mutator.mutate(
                enhancement_factor=original, verbose=False)

        # The input of mutate() must not be modified in place.
        self.assertEqual(original, enhancement_factors.f_x_wb97mv_short)
        self.assertEqual(mutation_type, 'randomize_instruction_list')
        self.assertEqual(mutated.instruction_list, scripted_instructions)
Exemplo n.º 2
0
    def setUp(self):
        """Creates the functional, evaluator and CMA-ES optimizer fixtures.

        The functional uses the same three-instruction enhancement factor
        for f_x (shared parameter 'x') and f_css (shared parameter 'y'),
        while f_cos has no instructions.
        """
        super().setUp()

        def make_quadratic_factor(parameter_name):
            # enhancement_factor += (parameter + 1) ** 2
            return enhancement_factors.EnhancementFactor(
                shared_parameter_names=[parameter_name],
                variable_names=['var1', 'var2', 'enhancement_factor'],
                instruction_list=[
                    instructions.AdditionBy1Instruction(
                        'var1', parameter_name),
                    instructions.Power2Instruction('var2', 'var1'),
                    instructions.AdditionInstruction('enhancement_factor',
                                                     'enhancement_factor',
                                                     'var2'),
                ])

        # objective = (x + 1) ** 2 + (y + 1) ** 2
        exchange_factor = make_quadratic_factor('x')
        same_spin_factor = make_quadratic_factor('y')
        opposite_spin_factor = enhancement_factors.EnhancementFactor(
            variable_names=['enhancement_factor'])
        self.functional = xc_functionals.XCFunctional(
            f_x=exchange_factor,
            f_css=same_spin_factor,
            f_cos=opposite_spin_factor,
        )

        evaluator_kwargs = {
            'num_grids_for_mols': [5, 5],
            'rho_weights': np.ones(10) * 0.2,
            'formula_matrix': np.eye(2),
            'targets': np.zeros(2),
            'sample_weights': np.ones(2),
            'e_lda_x': np.ones(10),
            'e_lda_css': np.ones(10),
            'e_lda_cos': np.zeros(10),
            'features': {},
        }
        self.evaluator = evaluators.Evaluator(**evaluator_kwargs)

        self.optimizer = optimizers.CMAESOptimizer(
            evaluator=self.evaluator,
            initial_parameters_mean=0.,
            initial_parameters_std=0.,
            sigma0=1.)
Exemplo n.º 3
0
    def setUp(self):
        """Builds random inputs and a small enhancement factor for the tests.

        Creates two random feature arrays of length 5, two random scalar
        shared parameters, one random 'gamma_utransform' bound parameter and
        three zero-initialized variables (the last one being
        'enhancement_factor'). The enhancement factor under test chains a
        multiplication, two additions, a square and a UTransform.
        """
        super().setUp()

        self.num_features = 2
        self.num_shared_parameters = 2
        self.num_variables = 3

        # Random draws happen in the order: features, shared parameters,
        # bound parameter.
        self.features = {}
        for index in range(self.num_features):
            self.features[f'feature_{index}'] = np.random.rand(5)

        self.shared_parameters = {}
        for index in range(self.num_shared_parameters):
            self.shared_parameters[
                f'shared_parameter_{index}'] = np.random.rand()

        self.bound_parameters = {'gamma_utransform': np.random.rand()}

        self.parameters = dict(self.shared_parameters)
        self.parameters.update(self.bound_parameters)

        # All variables start at zero; 'enhancement_factor' comes last.
        variable_names = [
            f'variable_{index}' for index in range(self.num_variables - 1)
        ]
        variable_names.append('enhancement_factor')
        self.variables = {name: np.zeros(5) for name in variable_names}

        instruction_list = [
            instructions.MultiplicationInstruction('variable_0', 'feature_0',
                                                   'shared_parameter_0'),
            instructions.AdditionInstruction('variable_1', 'feature_1',
                                             'shared_parameter_1'),
            instructions.AdditionInstruction('variable_1', 'variable_1',
                                             'variable_0'),
            instructions.Power2Instruction('enhancement_factor',
                                           'variable_1'),
            instructions.UTransformInstruction('enhancement_factor',
                                               'enhancement_factor'),
        ]
        self.enhancement_factor = enhancement_factors.EnhancementFactor(
            feature_names=list(self.features),
            shared_parameter_names=list(self.shared_parameters),
            variable_names=list(self.variables),
            instruction_list=instruction_list)
    def test_randomize_instruction_list_fixed_num_instructions(self):
        """Tests randomize_instruction_list with an explicit num_instructions.

        get_random_instruction is patched to return three scripted
        instructions; with num_instructions=2 the returned instruction list
        is expected to equal the first two of them.
        """
        scripted_instructions = [
            instruction_class('enhancement_factor', 'c10', 'u')
            for instruction_class in (
                instructions.AdditionInstruction,
                instructions.MultiplicationInstruction,
                instructions.DivisionInstruction,
            )
        ]
        scripted_random_instruction = mock.patch.object(
            mutators.EnhancementFactorMutator,
            'get_random_instruction',
            side_effect=scripted_instructions)

        with scripted_random_instruction:
            mutator = mutators.EnhancementFactorMutator(
                mutation_pool={'randomize_instruction_list': 1.})
            enhancement_factor = copy.deepcopy(
                enhancement_factors.f_x_wb97mv_short)
            randomized_list, _, _, _ = mutator.randomize_instruction_list(
                enhancement_factor, num_instructions=2)

            self.assertEqual(randomized_list, scripted_instructions[:2])
Exemplo n.º 5
0
                              instructions.AdditionBy1Instruction(
                                  'enhancement_factor', 'enhancement_factor'),
                          ])

# B97 enhancement factor as a function of u.
# The instruction list below assembles the quadratic power series
# c0 + c1 * u + c2 * u ** 2, where u is supplied as an input feature and
# c0, c1, c2 are shared parameters.
f_b97_u = EnhancementFactor(
    feature_names=['u'],
    shared_parameter_names=['c0', 'c1', 'c2'],
    variable_names=['c1u', 'c2u2', 'enhancement_factor'],
    instruction_list=[
        # building blocks of the series: c1u = u * c1 and c2u2 = u ** 2
        # (u itself is an input feature and is not computed here)
        instructions.MultiplicationInstruction('c1u', 'u', 'c1'),
        instructions.Power2Instruction('c2u2', 'u'),
        # power series: scale u ** 2 by c2 in place, then sum the three terms
        instructions.MultiplicationInstruction('c2u2', 'c2', 'c2u2'),
        instructions.AdditionInstruction('enhancement_factor', 'c0', 'c1u'),
        instructions.AdditionInstruction('enhancement_factor',
                                         'enhancement_factor', 'c2u2'),
    ])

# B97 enhancement factor as a function of u (short version)
f_b97_u_short = EnhancementFactor(
    feature_names=['u'],
    shared_parameter_names=['c0', 'c1', 'c2'],
    variable_names=['u2', 'enhancement_factor'],
    instruction_list=[
        instructions.AdditionInstruction('enhancement_factor',
                                         'enhancement_factor', 'c0'),
        instructions.MultiplicationAdditionInstruction('enhancement_factor',
                                                       'c1', 'u'),
        instructions.Power2Instruction('u2', 'u'),
Exemplo n.º 6
0
class EnhancementFactorTest(parameterized.TestCase):
    def setUp(self):
        """Builds random inputs and a small enhancement factor for the tests.

        Creates two random feature arrays of length 5, two random scalar
        shared parameters, one random 'gamma_utransform' bound parameter and
        three zero-initialized variables (the last one being
        'enhancement_factor'). The enhancement factor under test chains a
        multiplication, two additions, a square and a UTransform.
        """
        super().setUp()

        self.num_features = 2
        self.num_shared_parameters = 2
        self.num_variables = 3

        # Random draws happen in the order: features, shared parameters,
        # bound parameter.
        self.features = {}
        for index in range(self.num_features):
            self.features[f'feature_{index}'] = np.random.rand(5)

        self.shared_parameters = {}
        for index in range(self.num_shared_parameters):
            self.shared_parameters[
                f'shared_parameter_{index}'] = np.random.rand()

        self.bound_parameters = {'gamma_utransform': np.random.rand()}

        self.parameters = dict(self.shared_parameters)
        self.parameters.update(self.bound_parameters)

        # All variables start at zero; 'enhancement_factor' comes last.
        variable_names = [
            f'variable_{index}' for index in range(self.num_variables - 1)
        ]
        variable_names.append('enhancement_factor')
        self.variables = {name: np.zeros(5) for name in variable_names}

        instruction_list = [
            instructions.MultiplicationInstruction('variable_0', 'feature_0',
                                                   'shared_parameter_0'),
            instructions.AdditionInstruction('variable_1', 'feature_1',
                                             'shared_parameter_1'),
            instructions.AdditionInstruction('variable_1', 'variable_1',
                                             'variable_0'),
            instructions.Power2Instruction('enhancement_factor',
                                           'variable_1'),
            instructions.UTransformInstruction('enhancement_factor',
                                               'enhancement_factor'),
        ]
        self.enhancement_factor = enhancement_factors.EnhancementFactor(
            feature_names=list(self.features),
            shared_parameter_names=list(self.shared_parameters),
            variable_names=list(self.variables),
            instruction_list=instruction_list)

    def test_constructor(self):
        """Checks the feature, parameter and variable counts of the fixture."""
        factor = self.enhancement_factor

        self.assertEqual(factor.num_features, self.num_features)
        # One extra parameter comes from the UTransform instruction.
        expected_num_parameters = self.num_shared_parameters + 1
        self.assertEqual(factor.num_parameters, expected_num_parameters)
        self.assertEqual(factor.num_variables, self.num_variables)

    def test_constructor_without_enhancement_factor_in_variable_names(self):
        """Expects a ValueError when variable_names lacks 'enhancement_factor'."""
        expected_message = '"enhancement_factor" not found in variable_names.'
        empty_arguments = {
            'feature_names': [],
            'shared_parameter_names': [],
            'variable_names': [],
            'instruction_list': [],
        }

        with self.assertRaisesRegex(ValueError, expected_message):
            enhancement_factors.EnhancementFactor(**empty_arguments)

    def test_constructor_with_repeated_name(self):
        """Expects a ValueError when a feature and a parameter share a name."""
        repeated_name = 'var'
        expected_message = 'Repeated names found in input.'

        with self.assertRaisesRegex(ValueError, expected_message):
            enhancement_factors.EnhancementFactor(
                feature_names=[repeated_name],
                shared_parameter_names=[repeated_name],
                variable_names=['enhancement_factor'],
                instruction_list=[])

    def test_constructor_with_wrong_instruction_type(self):
        """Expects a TypeError when instruction_list holds a non-Instruction."""
        expected_message = (r"1 is of type <class 'int'>, not an "
                            'instance of instructions.Instruction')
        not_an_instruction = 1

        with self.assertRaisesRegex(TypeError, expected_message):
            enhancement_factors.EnhancementFactor(
                feature_names=list(self.features),
                shared_parameter_names=list(self.shared_parameters),
                variable_names=list(self.variables),
                instruction_list=[not_an_instruction])

    @parameterized.parameters(
        (instructions.Power2Instruction('variable_0', 'var'),
         (r'Instruction variable_0 = var \*\* 2 contains invalid input argument '
          'var')),
        (instructions.AdditionInstruction('variable_0', 'shared_parameter_1',
                                          'gamma_utransform'),
         (r'Instruction variable_0 = shared_parameter_1 \+ gamma_utransform '
          'contains invalid input argument gamma_utransform')),
    )
    def test_constructor_with_invalid_input(self, instruction, error_message):
        """Expects a ValueError for an instruction with an invalid input name.

        Args:
          instruction: the instruction passed as the only entry of
            instruction_list.
          error_message: regular expression the ValueError message must match.
        """
        constructor_arguments = {
            'feature_names': list(self.features),
            'shared_parameter_names': list(self.shared_parameters),
            'variable_names': list(self.variables),
            'instruction_list': [instruction],
        }

        with self.assertRaisesRegex(ValueError, error_message):
            enhancement_factors.EnhancementFactor(**constructor_arguments)

    @parameterized.parameters(
        (instructions.Power2Instruction('feature_0', 'shared_parameter_0'),
         (r'Instruction feature_0 = shared_parameter_0 \*\* 2 contains '
          'invalid output argument feature_0')),
        (instructions.AdditionInstruction('feature_1', 'shared_parameter_1',
                                          'variable_1'),
         (r'Instruction feature_1 = shared_parameter_1 \+ variable_1 contains '
          'invalid output argument feature_1')),
        (instructions.Power4Instruction('bound_parameter_1',
                                        'shared_parameter_1'),
         (r'Instruction bound_parameter_1 = shared_parameter_1 \*\* 4 contains '
          'invalid output argument bound_parameter_1')),
    )
    def test_constructor_with_invalid_output(self, instruction, error_message):
        """Expects a ValueError for an instruction with an invalid output name.

        Args:
          instruction: the instruction passed as the only entry of
            instruction_list.
          error_message: regular expression the ValueError message must match.
        """
        constructor_arguments = {
            'feature_names': list(self.features),
            'shared_parameter_names': list(self.shared_parameters),
            'variable_names': list(self.variables),
            'instruction_list': [instruction],
        }

        with self.assertRaisesRegex(ValueError, error_message):
            enhancement_factors.EnhancementFactor(**constructor_arguments)

    @parameterized.parameters(False, True)
    def test_eval(self, use_jax):
        """Compares eval() of the fixture with a direct numpy computation.

        Args:
          use_jax: forwarded to eval() as the use_jax flag.
        """
        scaled_feature = (self.features['feature_0'] *
                          self.parameters['shared_parameter_0'])
        shifted_feature = (self.features['feature_1'] +
                           self.parameters['shared_parameter_1'])
        # Square of the summed terms, scaled by gamma, then mapped through
        # t / (1 + t).
        scaled_square = (self.parameters['gamma_utransform'] *
                         (scaled_feature + shifted_feature)**2)
        expected_f = scaled_square / (1. + scaled_square)

        actual_f = self.enhancement_factor.eval(self.features,
                                                self.parameters,
                                                use_jax=use_jax)

        np.testing.assert_allclose(actual_f, expected_f)

    @parameterized.parameters(False, True)
    def test_b97_u_enhancement_factor(self, use_jax):
        """Checks that f_b97_u.eval() agrees with gga.f_b97 on random x.

        Args:
          use_jax: forwarded to eval() as the use_jax flag.
        """
        gamma_x = 0.004
        parameters = {'c0': 0.8094, 'c1': 0.5073, 'c2': 0.7481}
        x = np.random.rand(5)
        features = {'u': gga.u_b97(x, gamma=gamma_x)}
        expected_f = gga.f_b97(x)

        actual_f = enhancement_factors.f_b97_u.eval(features=features,
                                                    parameters=parameters,
                                                    use_jax=use_jax)

        np.testing.assert_allclose(actual_f, expected_f)

    @parameterized.parameters(False, True)
    def test_b97_u_short_enhancement_factor(self, use_jax):
        """Checks that f_b97_u_short.eval() agrees with gga.f_b97 on random x.

        Args:
          use_jax: forwarded to eval() as the use_jax flag.
        """
        gamma_x = 0.004
        parameters = {'c0': 0.8094, 'c1': 0.5073, 'c2': 0.7481}
        x = np.random.rand(5)
        features = {'u': gga.u_b97(x, gamma=gamma_x)}
        expected_f = gga.f_b97(x)

        actual_f = enhancement_factors.f_b97_u_short.eval(
            features=features, parameters=parameters, use_jax=use_jax)

        np.testing.assert_allclose(actual_f, expected_f)

    @parameterized.parameters(False, True)
    def test_b97_x2_enhancement_factor(self, use_jax):
        """Checks that f_b97_x2.eval() agrees with gga.f_b97 on random x.

        Args:
          use_jax: forwarded to eval() as the use_jax flag.
        """
        parameters = {
            'c0': 0.8094,
            'c1': 0.5073,
            'c2': 0.7481,
            'gamma': 0.004,
        }
        x = np.random.rand(5)
        # The x2 feature is x ** 2 scaled by (1 / 2) ** (-2 / 3).
        features = {'x2': (1 / 2)**(-2 / 3) * x**2}
        expected_f = gga.f_b97(x)

        actual_f = enhancement_factors.f_b97_x2.eval(features=features,
                                                     parameters=parameters,
                                                     use_jax=use_jax)

        np.testing.assert_allclose(actual_f, expected_f)

    @parameterized.parameters(False, True)
    def test_b97_x2_short_enhancement_factor(self, use_jax):
        """Checks that f_b97_x2_short.eval() agrees with gga.f_b97 on random x.

        Args:
          use_jax: forwarded to eval() as the use_jax flag.
        """
        parameters = {
            'c0': 0.8094,
            'c1': 0.5073,
            'c2': 0.7481,
            'gamma_utransform': 0.004,
        }
        x = np.random.rand(5)
        # The x2 feature is x ** 2 scaled by (1 / 2) ** (-2 / 3).
        features = {'x2': (1 / 2)**(-2 / 3) * x**2}
        expected_f = gga.f_b97(x)

        actual_f = enhancement_factors.f_b97_x2_short.eval(
            features=features, parameters=parameters, use_jax=use_jax)

        np.testing.assert_allclose(actual_f, expected_f)

    @parameterized.parameters(
        (enhancement_factors.f_x_wb97mv, enhancement_factors.f_css_wb97mv,
         enhancement_factors.f_cos_wb97mv, 'gamma'),
        (enhancement_factors.f_x_wb97mv_short,
         enhancement_factors.f_css_wb97mv_short,
         enhancement_factors.f_cos_wb97mv_short, 'gamma_utransform'),
    )
    def test_wb97mv_enhancement_factors(self, f_x_wb97mv, f_css_wb97mv,
                                        f_cos_wb97mv, gamma_key):
        """Compares the wB97M-V enhancement factors with mgga.f_b97m.

        Args:
          f_x_wb97mv: enhancement factor evaluated with the 'x' parameters.
          f_css_wb97mv: enhancement factor evaluated with the 'ss' parameters.
          f_cos_wb97mv: enhancement factor evaluated with the 'os' parameters.
          gamma_key: name under which gamma is passed in the parameters dict.
        """
        rho = np.random.rand(5)
        x = np.random.rand(5)
        tau = np.random.rand(5)
        x2 = (1 / 2)**(-2 / 3) * x**2
        t = mgga.get_mgga_t(rho, tau, polarized=False)
        w = (t - 1) / (t + 1)

        # (gamma key, power series key, enhancement factor, coefficient names
        # in the order of the power series entries they are read from)
        channels = (
            ('gamma_x', 'power_series_x', f_x_wb97mv,
             ('c00', 'c10', 'c01')),
            ('gamma_ss', 'power_series_ss', f_css_wb97mv,
             ('c00', 'c10', 'c20', 'c43', 'c04')),
            ('gamma_os', 'power_series_os', f_cos_wb97mv,
             ('c00', 'c10', 'c20', 'c60', 'c21', 'c61')),
        )

        expected_values = [
            mgga.f_b97m(x,
                        t,
                        gamma=mgga.WB97MV_PARAMS[gamma_name],
                        power_series=mgga.WB97MV_PARAMS[series_name],
                        polarized=False)
            for gamma_name, series_name, _, _ in channels
        ]

        actual_values = []
        for gamma_name, series_name, factor, coefficient_names in channels:
            power_series = mgga.WB97MV_PARAMS[series_name]
            # The coefficient is the third entry of each power series term.
            parameters = {
                name: power_series[index][2]
                for index, name in enumerate(coefficient_names)
            }
            parameters[gamma_key] = mgga.WB97MV_PARAMS[gamma_name]
            actual_values.append(
                factor.eval(features={
                    'x2': x2,
                    'w': w
                },
                            parameters=parameters))

        for actual, expected in zip(actual_values, expected_values):
            np.testing.assert_allclose(actual, expected)

    def test_convert_enhancement_factor_to_and_from_dict(self):
        """Checks that from_dict(to_dict()) reproduces the fixture."""
        serialized = self.enhancement_factor.to_dict()
        restored = enhancement_factors.EnhancementFactor.from_dict(serialized)

        self.assertEqual(self.enhancement_factor, restored)

    @parameterized.parameters(
        enhancement_factors.f_empty,
        enhancement_factors.f_lda,
        enhancement_factors.f_b97_u,
        enhancement_factors.f_b97_u_short,
        enhancement_factors.f_b97_x2,
        enhancement_factors.f_b97_x2_short,
        enhancement_factors.f_x_wb97mv,
        enhancement_factors.f_css_wb97mv,
        enhancement_factors.f_cos_wb97mv,
        enhancement_factors.f_x_wb97mv_short,
        enhancement_factors.f_css_wb97mv_short,
        enhancement_factors.f_cos_wb97mv_short,
    )
    def test_make_isomorphic_copy(self, enhancement_factor):
        """Checks that an isomorphic copy evaluates like the original.

        The copy is evaluated with the same shared parameter values, keyed
        by the isomorphic-copy prefix followed by their position.

        Args:
          enhancement_factor: the enhancement factor to copy and evaluate.
        """
        features = {}
        for feature_name in enhancement_factor.feature_names:
            features[feature_name] = np.random.rand(5)

        shared_parameters = {}
        for parameter_name in enhancement_factor.shared_parameter_names:
            shared_parameters[parameter_name] = np.random.rand()

        prefix = enhancement_factor._isomorphic_copy_shared_parameter_prefix
        renamed_shared_parameters = {
            prefix + str(position): value
            for position, value in enumerate(shared_parameters.values())
        }

        bound_parameters = {}
        for parameter_name in enhancement_factor.bound_parameter_names:
            bound_parameters[parameter_name] = np.random.rand()

        enhancement_factor_copy = enhancement_factor.make_isomorphic_copy()

        original_value = enhancement_factor.eval(
            features=features,
            parameters={
                **shared_parameters,
                **bound_parameters
            })
        copy_value = enhancement_factor_copy.eval(
            features=features,
            parameters={
                **renamed_shared_parameters,
                **bound_parameters
            })
        np.testing.assert_allclose(original_value, copy_value)

    def test_make_isomorphic_copy_of_f_x_wb97mv_short(self):
        """Checks names of a copy of f_x_wb97mv_short with enlarged sizes."""
        requested_features = ['rho', 'x2', 'w']
        f_x_wb97mv_copy = (
            enhancement_factors.f_x_wb97mv_short.make_isomorphic_copy(
                feature_names=requested_features,
                num_shared_parameters=10,
                num_variables=10))

        parameter_prefix = (
            f_x_wb97mv_copy._isomorphic_copy_shared_parameter_prefix)
        variable_prefix = f_x_wb97mv_copy._isomorphic_copy_variable_prefix
        expected_parameter_names = [
            parameter_prefix + str(index) for index in range(10)
        ]
        # Nine prefixed variables followed by 'enhancement_factor'.
        expected_variable_names = [
            variable_prefix + str(index) for index in range(9)
        ]
        expected_variable_names.append('enhancement_factor')

        self.assertEqual(f_x_wb97mv_copy.feature_names, ['rho', 'x2', 'w'])
        self.assertEqual(f_x_wb97mv_copy.num_shared_parameters, 10)
        self.assertEqual(f_x_wb97mv_copy.shared_parameter_names,
                         expected_parameter_names)
        self.assertEqual(f_x_wb97mv_copy.variable_names,
                         expected_variable_names)

    def test_make_isomorphic_copy_enhancement_factor_variable_location(self):
        """Checks the copy ignores where 'enhancement_factor' sits in variables.

        'enhancement_factor' is moved to a random position in variable_names
        of a deep copy; the isomorphic copies of the original and of the
        shuffled instance are expected to be equal.
        """
        f_x_wb97mv_shuffled = copy.deepcopy(
            enhancement_factors.f_x_wb97mv_short)
        f_x_wb97mv_shuffled.variable_names.remove('enhancement_factor')
        new_position = np.random.randint(
            len(f_x_wb97mv_shuffled.variable_names))
        f_x_wb97mv_shuffled.variable_names.insert(new_position,
                                                  'enhancement_factor')

        reference_copy = (
            enhancement_factors.f_x_wb97mv_short.make_isomorphic_copy())
        shuffled_copy = f_x_wb97mv_shuffled.make_isomorphic_copy()
        self.assertEqual(reference_copy, shuffled_copy)

    def test_make_isomorphic_copy_repeated_feature_names(self):
        """Expects a ValueError when feature_names contains duplicates."""
        duplicated_feature_names = ['u', 'u']
        expected_message = 'Repeated feature names'

        with self.assertRaisesRegex(ValueError, expected_message):
            enhancement_factors.f_b97_u.make_isomorphic_copy(
                feature_names=duplicated_feature_names)

    def test_make_isomorphic_copy_wrong_feature_names(self):
        """Expects a ValueError when feature_names omits a used feature."""
        expected_message = (
            r"feature_names \['rho', 'x2'\] is not a superset of feature_names of "
            r"current instance \['w', 'x2'\]")
        incomplete_feature_names = ['rho', 'x2']

        with self.assertRaisesRegex(ValueError, expected_message):
            enhancement_factors.f_x_wb97mv.make_isomorphic_copy(
                feature_names=incomplete_feature_names)

    def test_make_isomorphic_copy_wrong_num_shared_parameters(self):
        """Expects a ValueError when num_shared_parameters is too small."""
        expected_message = ('num_shared_parameters 5 is smaller than '
                            'that of current instance 6')
        factor = enhancement_factors.f_cos_wb97mv_short

        with self.assertRaisesRegex(ValueError, expected_message):
            factor.make_isomorphic_copy(num_shared_parameters=5)

    def test_make_isomorphic_copy_wrong_num_variables(self):
        """Expects a ValueError when num_variables is too small."""
        expected_message = ('num_variables 3 is smaller than '
                            'that of current instance 5')
        factor = enhancement_factors.f_cos_wb97mv_short

        with self.assertRaisesRegex(ValueError, expected_message):
            factor.make_isomorphic_copy(num_variables=3)

    @parameterized.parameters(
        (enhancement_factors.f_b97_u, 3),
        (enhancement_factors.f_b97_u_short, 3),
        (enhancement_factors.f_b97_x2, 4),
        (enhancement_factors.f_b97_x2_short, 4),
        (enhancement_factors.f_x_wb97mv_short, 4),
    )
    def test_num_used_parameters(self, enhancement_factor,
                                 expected_num_used_parameters):
        """Checks num_used_parameters on a factor and on an enlarged copy.

        Args:
          enhancement_factor: the enhancement factor under test.
          expected_num_used_parameters: expected num_used_parameters, both
            for the factor itself and for its isomorphic copy made with
            num_shared_parameters=20.
        """
        self.assertEqual(enhancement_factor.num_used_parameters,
                         expected_num_used_parameters)

        enlarged_copy = enhancement_factor.make_isomorphic_copy(
            num_shared_parameters=20)
        self.assertEqual(enlarged_copy.num_used_parameters,
                         expected_num_used_parameters)

    def test_get_symbolic_expression(self):
        """Checks the unsimplified sympy expression of f_b97_x2_short."""
        c0, c1, c2, gamma, x = sympy.symbols('c0 c1 c2 gamma_utransform x')
        denominator = gamma * x**2 + 1.
        expected_expression = (c0 + c1 * gamma * x**2 / denominator +
                               c2 * gamma**2 * x**4 / denominator**2)

        actual_expression = (
            enhancement_factors.f_b97_x2_short.get_symbolic_expression(
                latex=False, simplify=False))

        self.assertEqual(actual_expression, expected_expression)

    def test_get_symbolic_expression_latex(self):
        """Checks the unsimplified LaTeX expression of f_b97_x2_short."""
        expected_latex = (
            r'c_{0} + \frac{c_{1} \gamma_{u} x^{2}}{\gamma_{u} x^{2} + 1.0} + '
            r'\frac{c_{2} \gamma_{u}^{2} x^{4}}{\left(\gamma_{u} x^{2} + '
            r'1.0\right)^{2}}')

        actual_latex = (
            enhancement_factors.f_b97_x2_short.get_symbolic_expression(
                latex=True, simplify=False))

        self.assertEqual(actual_latex, expected_latex)
class InstructionTest(parameterized.TestCase):

    # general tests
    def test_instructions_do_not_have_repeated_bound_parameters(self):
        """No bound parameter may appear more than once across all classes.

        Collects get_bound_parameters() of every class in
        instructions.INSTRUCTION_CLASSES and checks for duplicates.
        """
        all_bound_parameters = []
        for instruction_class in instructions.INSTRUCTION_CLASSES.values():
            all_bound_parameters.extend(
                instruction_class.get_bound_parameters())

        # Deduplicating must not shrink the collection.
        unique_bound_parameters = set(all_bound_parameters)
        self.assertEqual(len(all_bound_parameters),
                         len(unique_bound_parameters))

    @parameterized.parameters(*instructions.Instruction.__subclasses__())
    def test_class_attributes(self, instruction_class):
        """Checks input and bound-parameter counts of an instruction class.

        Args:
            instruction_class: a direct subclass of instructions.Instruction.
        """
        num_inputs = instruction_class.get_num_inputs()
        self.assertIn(num_inputs, [1, 2])

        num_bound_parameters = instruction_class.get_num_bound_parameters()
        self.assertGreaterEqual(num_bound_parameters, 0)

    def test_constructor_with_wrong_number_of_arguments(self):
        """Power4Instruction built with three arguments raises ValueError."""
        expected_message = (
            'Power4Instruction: wrong number of arguments. Expected 2, got 3')
        with self.assertRaisesRegex(ValueError, expected_message):
            instructions.Power4Instruction('a', 'b', 'c')

    @parameterized.parameters(
        (instructions.Power2Instruction('a', 'b'),
         instructions.Power2Instruction('a', 'b')),
        (instructions.AdditionInstruction('a', 'b', 'c'),
         instructions.AdditionInstruction('a', 'b', 'c')),
        (instructions.UTransformInstruction('a', 'b'),
         instructions.UTransformInstruction('a', 'b')),
    )
    def test_eq(self, instruction1, instruction2):
        """Separately built instructions with equal arguments compare equal.

        Args:
            instruction1: first instruction instance.
            instruction2: second instruction instance, built with the same
                class and arguments as the first.
        """
        first, second = instruction1, instruction2
        self.assertEqual(first, second)

    @parameterized.parameters(
        # Same arguments, different class.
        (instructions.AdditionInstruction('a', 'b', 'c'),
         instructions.MultiplicationInstruction('a', 'b', 'c')),
        # Same class, different last argument.
        (instructions.AdditionInstruction('a', 'b', 'c'),
         instructions.AdditionInstruction('a', 'b', 'd')),
        # Different class and different number of arguments.
        (instructions.AdditionInstruction('a', 'b', 'c'),
         instructions.Power2Instruction('a', 'b')),
    )
    def test_eq_false(self, instruction1, instruction2):
        """Instructions differing in class or arguments compare unequal.

        Args:
            instruction1: first instruction instance.
            instruction2: second instruction instance that differs from the
                first.
        """
        first, second = instruction1, instruction2
        self.assertNotEqual(first, second)

    # test instructions without bound parameters
    @parameterized.parameters(
        (instructions.AdditionBy1Instruction, lambda a: a + 1.),
        (instructions.Power2Instruction, lambda a: a**2),
        (instructions.Power3Instruction, lambda a: a**3),
        (instructions.Power4Instruction, lambda a: a**4),
        (instructions.Power6Instruction, lambda a: a**6),
        (instructions.SquareRootInstruction, np.sqrt),
        (instructions.CubeRootInstruction, np.cbrt),
        (instructions.Log1PInstruction, np.log1p),
        (instructions.ExpInstruction, np.exp),
    )
    def test_unary_instruction(self, instruction_class, function):
        """Checks a one-input instruction against a reference function.

        The instruction is constructed with output 'b' and input 'a' and
        applied once with use_jax=False and once with use_jax=True.

        Args:
            instruction_class: unary instruction class under test.
            function: reference callable taking the scalar input.
        """
        workspace = {'a': np.random.rand()}
        for use_jax in (False, True):
            # A fresh instruction instance is built for every pass.
            instruction_class('b', 'a').apply(workspace, use_jax=use_jax)
            expected = function(workspace['a'])
            self.assertAlmostEqual(workspace['b'], expected)

    @parameterized.parameters(
        (instructions.AdditionInstruction, lambda a, b: a + b),
        (instructions.SubtractionInstruction, lambda a, b: a - b),
        (instructions.MultiplicationInstruction, lambda a, b: a * b),
        (instructions.DivisionInstruction, lambda a, b: a / b),
    )
    def test_binary_instruction(self, instruction_class, function):
        """Checks a two-input instruction against a reference function.

        The instruction is constructed with output 'c' and inputs 'a' and
        'b'; after apply(), workspace['c'] is compared with function(a, b)
        for both values of the use_jax flag.

        Args:
            instruction_class: binary instruction class under test.
            function: reference callable taking the two scalar inputs.
        """
        workspace = {'a': np.random.rand(), 'b': np.random.rand()}
        # Cover both settings of use_jax. The second pass used to repeat
        # use_jax=False, so the use_jax=True path was never checked here.
        for use_jax in (False, True):
            instruction_class('c', 'a', 'b').apply(workspace, use_jax=use_jax)
            self.assertAlmostEqual(workspace['c'],
                                   function(workspace['a'], workspace['b']))

    @parameterized.parameters(
        instructions.AdditionBy1Instruction,
        instructions.Power2Instruction,
        instructions.Power3Instruction,
        instructions.Power4Instruction,
        instructions.Power6Instruction,
        instructions.SquareRootInstruction,
        instructions.CubeRootInstruction,
        instructions.Log1PInstruction,
        instructions.ExpInstruction,
    )
    def test_unary_instruction_apply(self, instruction_class):
        """apply() and sympy_apply() of a unary instruction must agree.

        Args:
            instruction_class: unary instruction class under test.
        """
        input1 = np.random.rand()
        instruction = instruction_class('output', 'input1')

        numeric_workspace = {'input1': input1}
        instruction.apply(numeric_workspace)

        symbolic_workspace = {'input1': sympy.Symbol('input1')}
        instruction.sympy_apply(symbolic_workspace)

        # Substitute the number into the symbolic result before comparing.
        symbolic_value = symbolic_workspace['output'].subs({'input1': input1})
        self.assertAlmostEqual(numeric_workspace['output'], symbolic_value)

    @parameterized.parameters(
        instructions.AdditionInstruction,
        instructions.SubtractionInstruction,
        instructions.MultiplicationInstruction,
        instructions.DivisionInstruction,
    )
    def test_binary_instruction_apply(self, instruction_class):
        """apply() and sympy_apply() of a binary instruction must agree.

        Args:
            instruction_class: binary instruction class under test.
        """
        input1 = np.random.rand()
        input2 = np.random.rand()
        instruction = instruction_class('output', 'input1', 'input2')

        numeric_workspace = {'input1': input1, 'input2': input2}
        instruction.apply(numeric_workspace)

        symbolic_workspace = {
            name: sympy.Symbol(name) for name in ('input1', 'input2')
        }
        instruction.sympy_apply(symbolic_workspace)

        # Substitute the numbers into the symbolic result before comparing.
        substitutions = {'input1': input1, 'input2': input2}
        symbolic_value = symbolic_workspace['output'].subs(substitutions)
        self.assertAlmostEqual(numeric_workspace['output'], symbolic_value)

    # test instructions with bound parameters
    @parameterized.parameters(
        [0., 0., 0.],
        [0., 1., 0.],
        [1., 0., 0.],
        [1., 1., .5],
    )
    def test_utransform_instruction(self, x, gamma_utransform, expected_y):
        """Checks UTransformInstruction on hand-picked scalar inputs.

        Args:
            x: value stored under workspace key 'x'.
            gamma_utransform: value stored under workspace key
                'gamma_utransform'.
            expected_y: expected value of workspace['output'] after apply().
        """
        workspace = dict(x=x, gamma_utransform=gamma_utransform)
        instruction = instructions.UTransformInstruction('output', 'x')
        instruction.apply(workspace)
        actual_y = workspace['output']
        self.assertAlmostEqual(actual_y, expected_y)

    @parameterized.parameters(False, True)
    def test_pbex_instruction(self, use_jax):
        """Compares PBEXInstruction output with gga.f_x_pbe.

        Args:
            use_jax: flag forwarded to the instruction's apply().
        """
        workspace = {
            'x': np.random.rand(5),
            'kappa_pbex': np.random.rand(),
            'mu_pbex': np.random.rand()
        }
        instruction = instructions.PBEXInstruction('output', 'x')
        instruction.apply(workspace, use_jax=use_jax)
        # Remove the result first so that only the remaining workspace
        # entries are forwarded to the reference function.
        actual = workspace.pop('output')

        # Keep the part of each key before the first underscore, so
        # 'kappa_pbex' becomes keyword 'kappa' while 'x' stays 'x'.
        reference_kwargs = {}
        for name, value in workspace.items():
            reference_kwargs[name.partition('_')[0]] = value

        np.testing.assert_allclose(actual, gga.f_x_pbe(**reference_kwargs))

    @parameterized.parameters(False, True)
    def test_rpbex_instruction(self, use_jax):
        """Compares RPBEXInstruction output with gga.f_x_rpbe.

        Args:
            use_jax: flag forwarded to the instruction's apply().
        """
        workspace = {
            'x': np.random.rand(5),
            'kappa_rpbex': np.random.rand(),
            'mu_rpbex': np.random.rand()
        }
        instruction = instructions.RPBEXInstruction('output', 'x')
        instruction.apply(workspace, use_jax=use_jax)
        # Remove the result first so that only the remaining workspace
        # entries are forwarded to the reference function.
        actual = workspace.pop('output')

        # Keep the part of each key before the first underscore, so
        # 'kappa_rpbex' becomes keyword 'kappa' while 'x' stays 'x'.
        reference_kwargs = {}
        for name, value in workspace.items():
            reference_kwargs[name.partition('_')[0]] = value

        np.testing.assert_allclose(actual, gga.f_x_rpbe(**reference_kwargs))

    @parameterized.parameters(False, True)
    def test_b88_instruction(self, use_jax):
        """Compares B88XInstruction output with gga.f_x_b88.

        Args:
            use_jax: flag forwarded to the instruction's apply().
        """
        workspace = {'x': np.random.rand(5), 'beta_b88x': np.random.rand()}
        instruction = instructions.B88XInstruction('output', 'x')
        instruction.apply(workspace, use_jax=use_jax)
        # Remove the result first so that only the remaining workspace
        # entries are forwarded to the reference function.
        actual = workspace.pop('output')

        # Keep the part of each key before the first underscore, so
        # 'beta_b88x' becomes keyword 'beta' while 'x' stays 'x'.
        reference_kwargs = {}
        for name, value in workspace.items():
            reference_kwargs[name.partition('_')[0]] = value

        np.testing.assert_allclose(actual, gga.f_x_b88(**reference_kwargs))

    @parameterized.parameters(False, True)
    def test_pbec_instruction(self, use_jax):
        """Compares PBECInstruction output with gga.e_c_pbe_unpolarized.

        Args:
            use_jax: flag forwarded to the instruction's apply().
        """
        workspace = {
            'rho': np.random.rand(5),
            'sigma': np.random.rand(5),
            'beta_pbec': np.random.rand(),
            'gamma_pbec': np.random.rand()
        }
        instruction = instructions.PBECInstruction('output', 'rho', 'sigma')
        instruction.apply(workspace, use_jax=use_jax)
        # Remove the result first so that only the remaining workspace
        # entries are forwarded to the reference function.
        actual = workspace.pop('output')

        # Keep the part of each key before the first underscore, so
        # 'beta_pbec' becomes keyword 'beta'; 'rho' and 'sigma' are unchanged.
        reference_kwargs = {}
        for name, value in workspace.items():
            reference_kwargs[name.partition('_')[0]] = value

        np.testing.assert_allclose(
            actual, gga.e_c_pbe_unpolarized(**reference_kwargs))

    # test conversion
    @parameterized.parameters(*instructions.Instruction.__subclasses__())
    def test_convert_instruction_to_and_from_list(self, instruction_class):
        """to_list() followed by Instruction.from_list() is a round trip.

        Args:
            instruction_class: a direct subclass of instructions.Instruction.
        """
        num_args = instruction_class.get_num_args()
        # Placeholder argument names 'arg0', 'arg1', ...
        arg_names = [f'arg{index}' for index in range(num_args)]
        instruction = instruction_class(*arg_names)

        serialized = instruction.to_list()
        restored = instructions.Instruction.from_list(serialized)

        self.assertEqual(instruction, restored)

    def test_from_list_with_wrong_instruction_name(self):
        """from_list() with 'UnknownInstruction' raises ValueError."""
        expected_message = 'Invalid instruction class name: UnknownInstruction'
        serialized = ['UnknownInstruction', 'a', 'b', 'c']
        with self.assertRaisesRegex(ValueError, expected_message):
            instructions.Instruction.from_list(serialized)

    # test helper functions
    def test_is_unary_instruction_name(self):
        """Checks is_unary_instruction_name on two unary and one binary name."""
        is_unary = instructions.is_unary_instruction_name
        for unary_name in ('Power2Instruction', 'UTransformInstruction'):
            self.assertTrue(is_unary(unary_name))
        self.assertFalse(is_unary('AdditionInstruction'))

    def test_is_binary_instruction_name(self):
        """Checks is_binary_instruction_name on two binary and one unary name.

        'MultiplicationInstruction' and 'PBECInstruction' must be reported
        as binary; 'AdditionBy1Instruction' must not.
        """
        self.assertTrue(
            instructions.is_binary_instruction_name(
                'MultiplicationInstruction'))
        self.assertTrue(
            instructions.is_binary_instruction_name('PBECInstruction'))
        # Spell the class name exactly (capital 'B'). The earlier spelling
        # 'Additionby1Instruction' presumably matched no instruction class,
        # so the assertion did not exercise a real unary instruction.
        self.assertFalse(
            instructions.is_binary_instruction_name('AdditionBy1Instruction'))

    def test_get_unary_instruction_names_from_list(self):
        self.assertEqual(
            instructions.get_unary_instruction_names_from_list([
                'AdditionInstruction', 'Power2Instruction',
                'UnknownInstruction'
            ]), ['Power2Instruction'])

    def test_get_binary_instruction_names_from_list(self):
        self.assertEqual(
            instructions.get_binary_instruction_names_from_list([
                'AdditionInstruction', 'Power2Instruction',
                'UnknownInstruction'
            ]), ['AdditionInstruction'])

    @parameterized.parameters(0, 1, 2, 3)
    def test_get_instruction_names_with_signature_num_inputs(self, num_inputs):
        """Every name returned for num_inputs has exactly that many inputs.

        Args:
            num_inputs: value passed as the num_inputs filter.
        """
        matching_names = instructions.get_instruction_names_with_signature(
            num_inputs=num_inputs)
        for name in matching_names:
            instruction_class = instructions.INSTRUCTION_CLASSES[name]
            self.assertEqual(instruction_class.get_num_inputs(), num_inputs)

    @parameterized.parameters(0, 1, 2, 3)
    def test_get_instruction_names_with_num_bound_parameters(
            self, num_bound_parameters):
        """Every returned name has exactly num_bound_parameters parameters.

        Args:
            num_bound_parameters: value passed as the num_bound_parameters
                filter.
        """
        matching_names = instructions.get_instruction_names_with_signature(
            num_bound_parameters=num_bound_parameters)
        for name in matching_names:
            instruction_class = instructions.INSTRUCTION_CLASSES[name]
            self.assertEqual(instruction_class.get_num_bound_parameters(),
                             num_bound_parameters)

    @parameterized.parameters(0, 1, 2, 3)
    def test_get_instruction_names_with_max_num_bound_parameters(
            self, max_num_bound_parameters):
        """Every returned name has at most max_num_bound_parameters parameters.

        Args:
            max_num_bound_parameters: value passed as the
                max_num_bound_parameters filter.
        """
        matching_names = instructions.get_instruction_names_with_signature(
            max_num_bound_parameters=max_num_bound_parameters)
        for name in matching_names:
            instruction_class = instructions.INSTRUCTION_CLASSES[name]
            self.assertLessEqual(instruction_class.get_num_bound_parameters(),
                                 max_num_bound_parameters)