Example No. 1
 def _is_valid_pred(self, pred, raise_error=True):
     """ Description: checks that pred is a valid function to differentiate with respect to using jax """
     if not callable(pred):
         if raise_error:
             raise error.InvalidInput(
                 "Optimizer 'pred' input {} is not callable".format(pred))
         return False
     inputs = list(inspect.signature(pred).parameters)
     if 'x' not in inputs or 'params' not in inputs:
         if raise_error:
             raise error.InvalidInput(
                 "Optimizer 'pred' input {} must take variables named 'params' and 'x'"
                 .format(pred))
         return False
     try:
         grad_pred = grad(pred)
     except Exception as e:
         if raise_error:
             message = "JAX is unable to take gradient with respect to optimizer 'pred' input {}.\n".format(pred) + \
             "Please verify that input is implemented using JAX NumPy. Full error message: \n{}".format(e)
             raise error.InvalidInput(message)
         return False
     try:
         jit_grad_pred = jit(grad_pred)
     except Exception as e:
         if raise_error:
             message = "JAX jit optimization failed on 'pred' input {}. Full error message: \n{}".format(
                 pred, e)
             raise error.InvalidInput(message)
         return False
     return True
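
For reference, here is a minimal prediction function that would pass all three checks above: it is callable, its signature names both 'params' and 'x', and it uses only jax.numpy operations, so grad and jit succeed. The function name and shapes are illustrative, not taken from the library.

    import jax.numpy as np   # the examples alias jax.numpy as np and regular numpy as onp
    from jax import grad, jit

    def linear_pred(params, x):
        """Toy scalar predictor so the gradient can be evaluated directly."""
        return np.dot(params['w'], x) + params['b']

    grad_pred = jit(grad(linear_pred))   # differentiates w.r.t. the first argument, i.e. params
    g = grad_pred({'w': np.ones(3), 'b': 0.0}, np.arange(3.0))
    # g has the same structure as params: {'w': [0., 1., 2.], 'b': 1.0}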
Example No. 2
    def __init__(self, pred=None, loss=mse, hyperparameters={}):
        self.initialized = False
        self.hps = {'G': 1, 'c': 1, 'D': 1, 'exp_con': 0.5}
        self.hps.update(hyperparameters)
        for key, value in self.hps.items():
            if hasattr(self, key):
                raise error.InvalidInput(
                    "key {} is already an attribute in {}".format(key, self))
            setattr(self, key, value)  # store all hyperparameters
        self.A, self.Ainv = None, None
        self.pred, self.loss = pred, loss
        self.numpyify = lambda m: onp.array(m).astype(
            onp.double)  # maps jax.numpy to regular numpy
        if not hasattr(self, 'eta'):
            if 4 * self.G * self.D > self.exp_con:
                self.eta = 2 * self.G * self.D
            else:
                self.eta = 0.5 * self.exp_con
        if not hasattr(self, 'eps'):
            self.eps = 1.0 / ((self.eta**2) * (self.D**2))
        if self._is_valid_pred(pred,
                               raise_error=False) and self._is_valid_loss(
                                   loss, raise_error=False):
            self.set_predict(pred, loss=loss)

        @jit  # partial update step for every matrix in method weights list
        def partial_update(A, Ainv, grad, w):
            A = A + np.outer(grad, grad)
            inv_grad = Ainv @ grad
            Ainv = Ainv - np.outer(inv_grad,
                                   inv_grad) / (1 + grad.T @ Ainv @ grad)
            new_grad = np.reshape(Ainv @ grad, w.shape)
            return A, Ainv, new_grad

        self.partial_update = partial_update
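
The jitted partial_update keeps A and its inverse in sync with a rank-one Sherman-Morrison update, (A + g g^T)^-1 = A^-1 - (A^-1 g)(A^-1 g)^T / (1 + g^T A^-1 g), so no matrix inversion is needed per step. A small standalone check of that identity in regular numpy (the dimension and eps seed below are illustrative):

    import numpy as onp

    def partial_update_np(A, Ainv, grad):
        """One rank-one update of A and of its inverse via Sherman-Morrison."""
        A = A + onp.outer(grad, grad)
        inv_grad = Ainv @ grad
        Ainv = Ainv - onp.outer(inv_grad, inv_grad) / (1 + grad.T @ Ainv @ grad)
        return A, Ainv

    d, eps = 4, 0.1                                   # illustrative dimension and seed
    A, Ainv = eps * onp.eye(d), (1.0 / eps) * onp.eye(d)
    g = onp.arange(1.0, d + 1)
    A, Ainv = partial_update_np(A, Ainv, g)
    assert onp.allclose(Ainv, onp.linalg.inv(A))      # incremental inverse matches direct inversion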
Example No. 3
 def _store_optimizer(self, optimizer, pred):
     if isinstance(optimizer, Optimizer):
         optimizer.set_predict(pred)
         self.optimizer = optimizer
         return
     if isinstance(optimizer, type) and issubclass(optimizer, Optimizer):
         self.optimizer = optimizer(pred=pred)
         return
     raise error.InvalidInput("Optimizer input cannot be stored")
 def __init__(self, pred=None, loss=mse, hyperparameters={}):
     self.initialized = False
     self.hps = {"T":10000, "D":1,"G":1,"c":1}
     self.hps.update(hyperparameters)
     for key, value in self.hps.items():
         if hasattr(self, key):
             raise error.InvalidInput("key {} is already an attribute in {}".format(key, self))
         setattr(self, key, value) # store all hyperparameters
     self.pred = pred
     self.loss = loss
     if self._is_valid_pred(pred, raise_error=False) and self._is_valid_loss(loss, raise_error=False):
         self.set_predict(pred, loss=loss)
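
_store_optimizer accepts either an already-constructed Optimizer (set_predict(pred) is called on it) or an Optimizer subclass (it is instantiated with pred). A self-contained toy illustration of the two accepted forms, using stand-in classes rather than the library's own:

    class Optimizer:                                     # stand-in, not the library's class
        def __init__(self, pred=None):
            self.pred = pred
        def set_predict(self, pred):
            self.pred = pred

    class ToyOGD(Optimizer):
        pass

    class Method:
        def _store_optimizer(self, optimizer, pred):
            if isinstance(optimizer, Optimizer):
                optimizer.set_predict(pred)              # instance: attach pred to it
                self.optimizer = optimizer
                return
            if isinstance(optimizer, type) and issubclass(optimizer, Optimizer):
                self.optimizer = optimizer(pred=pred)    # class: instantiate it with pred
                return
            raise ValueError("Optimizer input cannot be stored")

    pred = lambda params, x: params @ x
    Method()._store_optimizer(ToyOGD(), pred)    # pass an instance
    Method()._store_optimizer(ToyOGD, pred)      # or pass the class itself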
Example No. 5
 def _is_valid_loss(self, loss, raise_error=True):
     """ Description: checks that loss is a valid function to differentiate with respect to using jax """
     if not callable(loss):
         if raise_error:
             raise error.InvalidInput(
                 "Optimizer 'loss' input {} is not callable".format(loss))
         return False
     inputs = list(inspect.signature(loss).parameters)
     if len(inputs) != 2:
         if raise_error:
             raise error.InvalidInput(
                 "Optimizer 'loss' input {} must take two arguments as input"
                 .format(loss))
         return False
     try:
         jit_grad_loss = jit(grad(loss))
     except Exception as e:
         if raise_error:
             message = "JAX jit-grad failed on 'loss' input {}. Full error message: \n{}".format(
                 loss, e)
             raise error.InvalidInput(message)
         return False
     return True
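
A loss that passes these checks just needs to be callable with exactly two arguments and written in jax.numpy; the default mse used throughout these examples is presumably of this shape. A minimal sketch (the argument order is an assumption):

    import jax.numpy as np
    from jax import grad, jit

    def mse(y_pred, y_true):
        """Mean squared error in jax.numpy, so jit(grad(...)) succeeds."""
        return np.mean(np.square(y_pred - y_true))

    jit_grad_mse = jit(grad(mse))                                   # gradient w.r.t. y_pred
    g = jit_grad_mse(np.array([1.0, 2.0]), np.array([0.0, 0.0]))    # -> [1., 2.]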
Example No. 6
    def __init__(self,
                 pred=None,
                 loss=mse,
                 learning_rate=1.0,
                 hyperparameters={}):
        self.initialized = False
        self.lr = learning_rate
        self.hyperparameters = {'max_norm': True, 'reg': 0.0}
        self.hyperparameters.update(hyperparameters)
        for key, value in self.hyperparameters.items():
            if hasattr(self, key):
                raise error.InvalidInput(
                    "key {} is already an attribute in {}".format(key, self))
            setattr(self, key, value)  # store all hyperparameters
        self.G = None
        self.pred = pred
        self.loss = loss
        if self._is_valid_pred(pred,
                               raise_error=False) and self._is_valid_loss(
                                   loss, raise_error=False):
            self.set_predict(pred, loss=loss)

        @jit  # AdaGrad-style update: accumulate squared gradients and scale each coordinate
        def _update(params, grad, G, max_norm):
            new_G = {
                k: (g + np.square(dw))
                for (k, g), dw in zip(G.items(), grad.values())
            }
            max_norm = np.where(
                max_norm,
                np.maximum(
                    max_norm,
                    np.linalg.norm(
                        [np.linalg.norm(dw) for dw in grad.values()])),
                max_norm)
            lr = self.lr / np.where(max_norm, max_norm, 1.)
            new_params = {
                k: (w - lr * dw / np.sqrt(g))
                for (k, w), dw, g in zip(params.items(), grad.values(),
                                         new_G.values())
            }
            return new_params, new_G, max_norm

        self._update = _update
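
The _update above is an AdaGrad-style step: G accumulates squared gradients per coordinate, each coordinate is scaled by 1/sqrt(G), and when max_norm is enabled the learning rate is further divided by the largest gradient norm seen so far. One step for a single parameter, in regular numpy and without the max_norm rescaling (values illustrative):

    import numpy as onp

    w, G, lr = onp.array([1.0, 1.0]), onp.zeros(2), 1.0
    dw = onp.array([0.5, 2.0])

    G = G + onp.square(dw)           # accumulate squared gradients: [0.25, 4.0]
    w = w - lr * dw / onp.sqrt(G)    # per-coordinate step; dw / sqrt(dw**2) == sign(dw)
    print(w)                         # [0. 0.]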
Example No. 7
    def __init__(self,
                 pred=None,
                 loss=mse,
                 learning_rate=1.0,
                 hyperparameters={}):
        self.initialized = False
        self.lr = learning_rate
        self.hps = {
            'reg': 0.00,
            'eps': 0.0001,
            'max_norm': 0,
            'project': False,
            'full_matrix': False
        }
        self.hps.update(hyperparameters)
        self.original_max_norm = self.hps['max_norm']
        for key, value in self.hps.items():
            if hasattr(self, key):
                raise error.InvalidInput(
                    "key {} is already an attribute in {}".format(key, self))
            setattr(self, key, value)  # store all hyperparameters
        self.A, self.Ainv = None, None
        self.pred, self.loss = pred, loss
        self.numpyify = lambda m: onp.array(m).astype(
            onp.double)  # maps jax.numpy to regular numpy

        if self._is_valid_pred(pred,
                               raise_error=False) and self._is_valid_loss(
                                   loss, raise_error=False):
            self.set_predict(pred, loss=loss)

        @jit  # partial update step for every matrix in method weights list
        def partial_update(A, Ainv, grad, w):
            A = A + np.outer(grad, grad)
            inv_grad = Ainv @ grad
            Ainv = Ainv - np.outer(inv_grad,
                                   inv_grad) / (1 + grad.T @ Ainv @ grad)
            new_grad = np.reshape(Ainv @ grad, w.shape)
            return A, Ainv, new_grad

        self.partial_update = partial_update
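
The eps hyperparameter above suggests that, per weight matrix w, A and Ainv are seeded as eps * I and (1/eps) * I over the flattened gradient before partial_update is applied; that seeding is not shown in this excerpt and is only an assumption here. A sketch of calling the jitted helper (repeated exactly as defined above) on a 2x2 weight:

    import jax.numpy as np
    from jax import jit

    @jit
    def partial_update(A, Ainv, grad, w):
        A = A + np.outer(grad, grad)
        inv_grad = Ainv @ grad
        Ainv = Ainv - np.outer(inv_grad, inv_grad) / (1 + grad.T @ Ainv @ grad)
        return A, Ainv, np.reshape(Ainv @ grad, w.shape)

    eps = 0.0001                                                   # the default 'eps' above
    w = np.ones((2, 2))                                            # illustrative weight matrix
    A, Ainv = eps * np.eye(w.size), (1.0 / eps) * np.eye(w.size)   # assumed seeding
    g = np.arange(1.0, w.size + 1.0)                               # flattened gradient
    A, Ainv, preconditioned = partial_update(A, Ainv, g, w)        # preconditioned has w's shape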
Example No. 8
 def __init__(self,
              pred=None,
              loss=mse,
              learning_rate=1.0,
              hyperparameters={}):
     self.initialized = False
     self.lr = learning_rate
     self.hyperparameters = {
         'reg': 0.00
     }  # L2 regularization coefficient, default value 0.0
     self.hyperparameters.update(hyperparameters)
     for key, value in self.hyperparameters.items():
         if hasattr(self, key):
             raise error.InvalidInput(
                 "key {} is already an attribute in {}".format(key, self))
         setattr(self, key, value)  # store all hyperparameters
     self.pred = pred
     self.loss = loss
     if self._is_valid_pred(pred,
                            raise_error=False) and self._is_valid_loss(
                                loss, raise_error=False):
         self.set_predict(pred, loss=loss)
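
The only hyperparameter here is reg, an L2-regularization coefficient. One common way such a term enters a gradient step (shown as an illustration, not necessarily how this library applies it) is as a penalty 0.5 * reg * ||params||^2, whose gradient is simply reg * params:

    import jax.numpy as np
    from jax import grad

    def l2_penalty(params, reg):
        """0.5 * reg * ||params||^2, summed over every array in the params dict."""
        return 0.5 * reg * sum(np.sum(np.square(w)) for w in params.values())

    params = {'w': np.array([1.0, -2.0]), 'b': np.array(3.0)}
    g = grad(l2_penalty)(params, 0.1)    # {'w': [0.1, -0.2], 'b': 0.3}, i.e. reg * params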
Example No. 9
 def __init__(self,
              pred=None,
              loss=mse,
              learning_rate=1.0,
              hyperparameters={}):
     self.initialized = False
     self.lr = learning_rate
     self.hyperparameters = {'T': 0, 'max_norm': True}
     self.hyperparameters.update(hyperparameters)
     self.original_max_norm = self.hyperparameters['max_norm']
     for key, value in self.hyperparameters.items():
         if hasattr(self, key):
             raise error.InvalidInput(
                 "key {} is already an attribute in {}".format(key, self))
         setattr(self, key, value)  # store all hyperparameters
     self.G = None
     self.pred = pred
     self.loss = loss
     if self._is_valid_pred(pred,
                            raise_error=False) and self._is_valid_loss(
                                loss, raise_error=False):
         self.set_predict(pred, loss=loss)
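
The update step for this optimizer is not shown in this excerpt, but its max_norm hyperparameter follows the same pattern visible in Examples No. 6 and 10: when max_norm is truthy, it tracks the largest gradient norm seen and the learning rate is divided by it, while a falsy value (as in Example No. 7's default of 0) presumably leaves the learning rate untouched, consistent with the np.where(max_norm, max_norm, 1.) guard. That scaling logic in isolation:

    import jax.numpy as np

    def scale_lr(lr, max_norm, grads):
        """Pattern from Examples No. 6 and 10: rescale lr by the largest gradient norm seen."""
        grad_norm = np.linalg.norm(np.array([np.linalg.norm(dw) for dw in grads.values()]))
        max_norm = np.where(max_norm, np.maximum(max_norm, grad_norm), max_norm)
        return lr / np.where(max_norm, max_norm, 1.), max_norm

    lr_t, max_norm = scale_lr(1.0, True, {'w': np.array([3.0, 4.0])})   # lr_t == 0.2, max_norm == 5.0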
Example No. 10
    def __init__(self,
                 pred=None,
                 loss=mse,
                 learning_rate=1.0,
                 hyperparameters={}):
        self.initialized = False
        self.lr = learning_rate
        self.hyperparameters = {
            'reg': 0.0,
            'beta_1': 0.9,
            'beta_2': 0.999,
            'eps': 1e-7,
            'max_norm': True
        }
        self.hyperparameters.update(hyperparameters)
        for key, value in self.hyperparameters.items():
            if hasattr(self, key):
                raise error.InvalidInput(
                    "key {} is already an attribute in {}".format(key, self))
            setattr(self, key, value)  # store all hyperparameters
        self.beta_1_t, self.beta_2_t = self.beta_1, self.beta_2
        self.m, self.v = None, None

        self.pred = pred
        self.loss = loss
        if self._is_valid_pred(pred,
                               raise_error=False) and self._is_valid_loss(
                                   loss, raise_error=False):
            self.set_predict(pred, loss=loss)

        @jit  # Adam-style helper update: moving averages, bias correction, adaptive step
        def _update(params, grad, m, v, max_norm, beta_1_t, beta_2_t):
            new_m = {
                k: self.beta_1 * m_i + (1. - self.beta_1) * dw
                for (k, m_i), dw in zip(m.items(), grad.values())
            }
            new_v = {
                k: self.beta_2 * v_i + (1. - self.beta_2) * np.square(dw)
                for (k, v_i), dw in zip(v.items(), grad.values())
            }
            # bias-corrected first- and second-moment estimates
            m_t = [m_i / (1 - beta_1_t) for m_i in new_m.values()]
            v_t = [v_i / (1 - beta_2_t) for v_i in new_v.values()]

            # maintain current power of betas
            beta_1_t, beta_2_t = beta_1_t * self.beta_1, beta_2_t * self.beta_2
            max_norm = np.where(
                max_norm,
                np.maximum(
                    max_norm,
                    np.linalg.norm(
                        [np.linalg.norm(dw) for dw in grad.values()])),
                max_norm)
            lr = self.lr / np.where(max_norm, max_norm, 1.)
            new_params = {
                k: (w - lr * m_i / (np.sqrt(v_i) + self.eps))
                for (k, w), v_i, m_i in zip(params.items(), v_t, m_t)
            }
            return new_params, new_m, new_v, max_norm, beta_1_t, beta_2_t

        self._update = _update
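
The helper above is a standard Adam step: m and v are exponential moving averages of the gradient and its square, bias-corrected by 1 - beta_1^t and 1 - beta_2^t (beta_1_t and beta_2_t hold the running powers of the betas). One step for a single parameter in regular numpy, using the defaults from above and illustrative values:

    import numpy as onp

    beta_1, beta_2, eps, lr = 0.9, 0.999, 1e-7, 1.0
    w, m, v = onp.array([1.0]), onp.zeros(1), onp.zeros(1)
    dw = onp.array([0.5])

    m = beta_1 * m + (1 - beta_1) * dw                  # first-moment moving average
    v = beta_2 * v + (1 - beta_2) * onp.square(dw)      # second-moment moving average
    m_hat = m / (1 - beta_1)                            # bias correction; beta_1_t == beta_1 on step 1
    v_hat = v / (1 - beta_2)
    w = w - lr * m_hat / (onp.sqrt(v_hat) + eps)        # ~[0.], since m_hat / sqrt(v_hat) ~= 1 here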