Ejemplo n.º 1
0
    def test_get_not_ASP_relevant_vars(self):
        """Check ``ASPHelper._get_not_ASP_relevant_vars`` against the program's parameters.

        The parameters of ``self.main_program``'s global block are captured
        once, before any optimization. The helper's result must match that
        list by name, in order, both before and after
        ``ASPHelper._minimize`` has been applied to the program.
        """

        def names_match(expected, actual):
            # Lists must have equal length and agree position-by-position on
            # the ``name`` attribute of each entry.
            if len(actual) != len(expected):
                return False
            return not any(got.name != want.name
                           for want, got in zip(expected, actual))

        # Snapshot taken before _minimize; reused for the post-optimization
        # comparison below.
        params = self.main_program.global_block().all_parameters()

        before_opt = ASPHelper._get_not_ASP_relevant_vars(self.main_program)
        self.assertTrue(names_match(params, before_opt))

        with fluid.program_guard(self.main_program, self.startup_program):
            ASPHelper._minimize(self.optimizer, self.loss, self.main_program,
                                self.startup_program)

        after_opt = ASPHelper._get_not_ASP_relevant_vars(self.main_program)
        self.assertTrue(names_match(params, after_opt))
Ejemplo n.º 2
0
    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        """Run minimization of ``loss`` through ``ASPHelper._minimize``.

        The wrapped optimizer ``self.inner_opt`` is handed to
        ``ASPHelper._minimize`` together with ``loss``. The remaining
        arguments are forwarded unchanged as keyword arguments.

        Args:
            loss: The loss to minimize.
            startup_program: Forwarded as ``startup_program``.
            parameter_list: Forwarded as ``parameter_list``.
            no_grad_set: Forwarded as ``no_grad_set``.

        Returns:
            tuple: ``(optimize_ops, params_grads)`` as produced by
            ``ASPHelper._minimize``.
        """
        forwarded = {
            'startup_program': startup_program,
            'parameter_list': parameter_list,
            'no_grad_set': no_grad_set,
        }
        # Unpack explicitly so the result is always re-packed as a 2-tuple.
        ops, grads = ASPHelper._minimize(self.inner_opt, loss, **forwarded)
        return ops, grads