def setOptimizingParams(self,
                        damping_factor_obj: float = 1.0,
                        damping_update_factor_obj: float = 0.999,
                        damping_update_frequency_obj: int = 5,
                        damping_factor_probe: float = 1.0,
                        damping_update_factor_probe: float = 0.999,
                        damping_update_frequency_probe: int = 5):
    """Create the Curveball optimizer(s) and their minimize ops inside ``self.graph``.

    An optimizer for the object variable ``self._tf_obj`` is always created.
    A second one for the probe variable ``self._tf_probe`` is created only when
    ``self._probe_recons`` is truthy. Both share ``self._training_loss_fn``.

    Args:
        damping_factor_obj: Forwarded to the object optimizer as ``damping_factor``.
        damping_update_factor_obj: Forwarded to the object optimizer as
            ``damping_update_factor``.
        damping_update_frequency_obj: Forwarded to the object optimizer as
            ``damping_update_frequency``.
        damping_factor_probe: Probe-optimizer counterpart of ``damping_factor_obj``.
        damping_update_factor_probe: Probe-optimizer counterpart of
            ``damping_update_factor_obj``.
        damping_update_frequency_probe: Probe-optimizer counterpart of
            ``damping_update_frequency_obj``. Annotated ``int`` to match the
            object parameter and the integer default.

    Side effects:
        Sets ``self._optparams.obj_optimizer`` and ``obj_minimize_op``.
        Optionally sets ``probe_optimizer`` and ``probe_minimize_op``.
        Sets ``self._optparams.training_loss_tensor`` and
        ``self._optimizers_defined = True``.
    """
    with self.graph.as_default():
        self._optparams.obj_optimizer = Curveball(
            input_var=self._tf_obj,
            predictions_fn=self._training_predictions_as_obj_fn,
            loss_fn=self._training_loss_fn,
            damping_factor=damping_factor_obj,
            damping_update_factor=damping_update_factor_obj,
            damping_update_frequency=damping_update_frequency_obj,
            name='obj_opt')
        self._optparams.obj_minimize_op = self._optparams.obj_optimizer.minimize()

        if self._probe_recons:
            self._optparams.probe_optimizer = Curveball(
                input_var=self._tf_probe,
                predictions_fn=self._training_predictions_as_probe_fn,
                loss_fn=self._training_loss_fn,
                damping_factor=damping_factor_probe,
                damping_update_factor=damping_update_factor_probe,
                damping_update_frequency=damping_update_frequency_probe,
                name='probe_opt')
            self._optparams.probe_minimize_op = self._optparams.probe_optimizer.minimize()

    # The reported training loss is taken from the object optimizer's (private)
    # loss tensor; it does not depend on whether a probe optimizer exists.
    self._optparams.training_loss_tensor = self._optparams.obj_optimizer._loss_fn_tensor
    self._optimizers_defined = True
def __init__(self,
             input_var: tf.Variable,
             predictions_fn: Callable,
             loss_fn: Callable,
             diag_hessian_fn: Callable = None,
             initial_update_delay: int = 0,
             update_frequency: int = 1,
             **extra_init_kwargs):
    """Initialise the base optimizer and build the wrapped ``Curveball`` instance.

    Args:
        input_var: Variable handed to Curveball as ``input_var``.
        predictions_fn: Forwarded to Curveball as ``predictions_fn``.
        loss_fn: Forwarded to Curveball as ``loss_fn``.
        diag_hessian_fn: Forwarded to Curveball as ``diag_hessian_fn``
            (``None`` by default).
        initial_update_delay: Passed positionally to the parent ``__init__``;
            its meaning is defined by the parent class.
        update_frequency: Passed positionally to the parent ``__init__``;
            its meaning is defined by the parent class.
        **extra_init_kwargs: Any further keyword arguments, forwarded
            unchanged to the Curveball constructor.
    """
    # NOTE: ``extra_init_kwargs`` is intentionally unannotated. Curveball
    # accepts non-int keyword arguments here, e.g. float ``damping_factor``,
    # str ``name`` and bool ``squared_loss``. The former ``: int`` annotation
    # (meaning "every value is an int") was therefore wrong.
    super().__init__(initial_update_delay, update_frequency)
    self._optimizer = Curveball(input_var=input_var,
                                predictions_fn=predictions_fn,
                                loss_fn=loss_fn,
                                diag_hessian_fn=diag_hessian_fn,
                                **extra_init_kwargs)
def setOptimizingParams(self,
                        damping_factor_obj: float = 1.0,
                        damping_update_factor_obj: float = 0.999,
                        damping_update_frequency_obj: int = 5,
                        damping_factor_probe: float = 1.0,
                        damping_update_factor_probe: float = 0.999,
                        damping_update_frequency_probe: int = 5,
                        update_cond_threshold_low: float = 0.5,
                        update_cond_threshold_high: float = 1.5):
    """Create the Curveball optimizer(s) and their minimize ops inside ``self.graph``.

    An optimizer for ``self._tf_obj`` is always created. One for
    ``self._tf_probe`` is created only when ``self._probe_recons`` is truthy.
    The loss type selects how the loss curvature is handed to Curveball (see
    the comments on the branch below).

    Args:
        damping_factor_obj: Forwarded to the object optimizer as ``damping_factor``.
        damping_update_factor_obj: Forwarded to the object optimizer as
            ``damping_update_factor``.
        damping_update_frequency_obj: Forwarded to the object optimizer as
            ``damping_update_frequency``.
        damping_factor_probe: Probe-optimizer counterpart of ``damping_factor_obj``.
        damping_update_factor_probe: Probe-optimizer counterpart of
            ``damping_update_factor_obj``.
        damping_update_frequency_probe: Probe-optimizer counterpart of
            ``damping_update_frequency_obj``. Annotated ``int`` to match the
            object parameter and the integer default.
        update_cond_threshold_low: Forwarded as ``update_cond_threshold_low``
            to every optimizer created here.
        update_cond_threshold_high: Forwarded as ``update_cond_threshold_high``
            to every optimizer created here.

    Side effects:
        Sets ``self._optparams.obj_optimizer`` and ``obj_minimize_op``.
        Optionally sets ``probe_optimizer`` and ``probe_minimize_op``.
        Sets ``self._optparams.training_loss_tensor`` and
        ``self._optimizers_defined = True``.
    """
    if self._loss_type in ["poisson", "poisson_surrogate"]:
        # Poisson-type losses: supply an explicit diagonal loss-Hessian
        # function and tell Curveball the loss is not a squared loss.
        loss_hessian_fn = self._training_loss_hessian_fn
        squared_loss = False
    else:
        # Every other loss type is flagged to Curveball as a squared loss,
        # with no explicit Hessian function supplied.
        loss_hessian_fn = None
        squared_loss = True

    with self.graph.as_default():
        self._optparams.obj_optimizer = Curveball(
            input_var=self._tf_obj,
            predictions_fn=self._training_predictions_as_obj_fn,
            loss_fn=self._training_loss_fn,
            damping_factor=damping_factor_obj,
            damping_update_factor=damping_update_factor_obj,
            damping_update_frequency=damping_update_frequency_obj,
            update_cond_threshold_low=update_cond_threshold_low,
            update_cond_threshold_high=update_cond_threshold_high,
            name='obj_opt',
            diag_hessian_fn=loss_hessian_fn,
            squared_loss=squared_loss)
        self._optparams.obj_minimize_op = self._optparams.obj_optimizer.minimize()

        if self._probe_recons:
            # The update-condition thresholds carry no _obj/_probe suffix, so
            # they apply to the probe optimizer as well, not only the object one.
            self._optparams.probe_optimizer = Curveball(
                input_var=self._tf_probe,
                predictions_fn=self._training_predictions_as_probe_fn,
                loss_fn=self._training_loss_fn,
                damping_factor=damping_factor_probe,
                damping_update_factor=damping_update_factor_probe,
                damping_update_frequency=damping_update_frequency_probe,
                update_cond_threshold_low=update_cond_threshold_low,
                update_cond_threshold_high=update_cond_threshold_high,
                name='probe_opt',
                diag_hessian_fn=loss_hessian_fn,
                squared_loss=squared_loss)
            self._optparams.probe_minimize_op = self._optparams.probe_optimizer.minimize()

    # The reported training loss comes from the object optimizer's (private)
    # loss tensor, independent of whether a probe optimizer exists.
    self._optparams.training_loss_tensor = self._optparams.obj_optimizer._loss_fn_tensor
    self._optimizers_defined = True
def setOptimizingParams(self):
    """Build the single Curveball optimizer for ``self._tf_var`` and its minimize op.

    For the "poisson" and "poisson_surrogate" loss types, the training-loss
    Hessian function is handed to Curveball and ``squared_loss`` is False.
    For any other loss type, no Hessian function is given and ``squared_loss``
    is True.

    Side effects:
        Sets ``self._optparams.optimizer``, ``minimize_op`` and
        ``training_loss_tensor``, then sets ``self._optimizers_defined = True``.
    """
    is_poisson_like = self._loss_type in ["poisson", "poisson_surrogate"]
    hessian_fn = self._training_loss_hessian_fn if is_poisson_like else None

    with self.graph.as_default():
        curveball = Curveball(input_var=self._tf_var,
                              predictions_fn=self._training_predictions_fn,
                              loss_fn=self._training_loss_fn,
                              name='opt',
                              diag_hessian_fn=hessian_fn,
                              squared_loss=not is_poisson_like)
        self._optparams.optimizer = curveball
        self._optparams.minimize_op = curveball.minimize()

    # Expose the optimizer's (private) loss tensor as the training loss.
    self._optparams.training_loss_tensor = curveball._loss_fn_tensor
    self._optimizers_defined = True
class CurveballOptimizer(Optimizer):
    """Adapter that exposes a ``Curveball`` instance through the ``Optimizer`` interface.

    The wrapped Curveball object is constructed eagerly in ``__init__``.
    Its minimize op is only created when ``setupMinimizeOp`` is called.
    """

    def __init__(self,
                 input_var: tf.Variable,
                 predictions_fn: Callable,
                 loss_fn: Callable,
                 diag_hessian_fn: Callable = None,
                 initial_update_delay: int = 0,
                 update_frequency: int = 1,
                 **extra_init_kwargs):
        """Initialise the base optimizer and build the wrapped ``Curveball``.

        Args:
            input_var: Variable handed to Curveball as ``input_var``.
            predictions_fn: Forwarded to Curveball as ``predictions_fn``.
            loss_fn: Forwarded to Curveball as ``loss_fn``.
            diag_hessian_fn: Forwarded to Curveball as ``diag_hessian_fn``
                (``None`` by default).
            initial_update_delay: Passed positionally to ``Optimizer.__init__``;
                its meaning is defined there.
            update_frequency: Passed positionally to ``Optimizer.__init__``;
                its meaning is defined there.
            **extra_init_kwargs: Any further keyword arguments, forwarded
                unchanged to the Curveball constructor.
        """
        # NOTE: ``extra_init_kwargs`` is intentionally unannotated. Curveball
        # accepts non-int keyword arguments here, e.g. float ``damping_factor``,
        # str ``name`` and bool ``squared_loss``. The former ``: int``
        # annotation (meaning "every value is an int") was therefore wrong.
        super().__init__(initial_update_delay, update_frequency)
        self._optimizer = Curveball(input_var=input_var,
                                    predictions_fn=predictions_fn,
                                    loss_fn=loss_fn,
                                    diag_hessian_fn=diag_hessian_fn,
                                    **extra_init_kwargs)

    def setupMinimizeOp(self):
        """Create the wrapped optimizer's minimize op and cache it on the instance."""
        self._minimize_op = self._optimizer.minimize()

    @property
    def minimize_op(self):
        """The op returned by the wrapped optimizer's ``minimize()``.

        Raises:
            AttributeError: If ``setupMinimizeOp`` has not been called yet.
                The exception type is unchanged, so ``hasattr`` checks behave
                as before; only the message is clearer.
        """
        try:
            return self._minimize_op
        except AttributeError:
            raise AttributeError(
                "minimize_op is not available until setupMinimizeOp() has been called"
            ) from None