def test_get_not_ASP_relevant_vars(self):
    """Check ``ASPHelper._get_not_ASP_relevant_vars`` against the program's parameters.

    The vars reported as "not ASP relevant" must match, by name and in the
    same order, the parameters listed by the main program's global block.
    This is checked twice: once on the freshly built program and once after
    ``ASPHelper._minimize`` has rewritten it. The second check ensures that
    minimization does not change the reported set.

    NOTE(review): the test expects *every* parameter to be "not ASP
    relevant"; presumably the fixture model has no ASP-prunable layers —
    confirm against the test's setUp.
    """

    def same_names_in_order(expected, actual):
        # Lengths are compared first, so that elements past the end of the
        # shorter list are never inspected.
        if len(actual) != len(expected):
            return False
        # Pairwise name comparison; all() stops at the first mismatch,
        # exactly like an early-return loop would.
        return all(got.name == want.name for got, want in zip(actual, expected))

    program_params = self.main_program.global_block().all_parameters()

    before_opt = ASPHelper._get_not_ASP_relevant_vars(self.main_program)
    self.assertTrue(same_names_in_order(program_params, before_opt))

    # Run ASP minimization inside the test's program scope, then re-query.
    with fluid.program_guard(self.main_program, self.startup_program):
        ASPHelper._minimize(self.optimizer, self.loss, self.main_program,
                            self.startup_program)
    after_opt = ASPHelper._get_not_ASP_relevant_vars(self.main_program)
    self.assertTrue(same_names_in_order(program_params, after_opt))
def minimize_impl(self,
                  loss,
                  startup_program=None,
                  parameter_list=None,
                  no_grad_set=None):
    """Minimize ``loss`` by delegating to ``ASPHelper._minimize``.

    The wrapped optimizer ``self.inner_opt`` is handed to ASPHelper, together
    with the caller's arguments, which are forwarded by keyword. No
    ``main_program`` is passed, so ASPHelper's own default for it applies.

    Args:
        loss: The loss to minimize.
        startup_program: Forwarded unchanged to ``ASPHelper._minimize``.
        parameter_list: Forwarded unchanged to ``ASPHelper._minimize``.
        no_grad_set: Forwarded unchanged to ``ASPHelper._minimize``.

    Returns:
        The ``(optimize_ops, params_grads)`` pair produced by
        ``ASPHelper._minimize``.
    """
    return ASPHelper._minimize(
        self.inner_opt,
        loss,
        startup_program=startup_program,
        parameter_list=parameter_list,
        no_grad_set=no_grad_set,
    )