def _is_valid_pred(self, pred, raise_error=True):
    """
    Description: checks that pred is a valid function to differentiate with respect to using jax
    """
    if not callable(pred):
        if raise_error:
            raise error.InvalidInput(
                "Optimizer 'pred' input {} is not callable".format(pred))
        return False
    inputs = list(inspect.signature(pred).parameters)
    if 'x' not in inputs or 'params' not in inputs:
        if raise_error:
            raise error.InvalidInput(
                "Optimizer 'pred' input {} must take variables named 'params' and 'x'"
                .format(pred))
        return False
    try:
        grad_pred = grad(pred)
    except Exception as e:
        if raise_error:
            message = "JAX is unable to take gradient with respect to optimizer 'pred' input {}.\n".format(pred) + \
                "Please verify that input is implemented using JAX NumPy. Full error message: \n{}".format(e)
            raise error.InvalidInput(message)
        return False
    try:
        jit_grad_pred = jit(grad_pred)
    except Exception as e:
        if raise_error:
            message = "JAX jit optimization failed on 'pred' input {}. Full error message: \n{}".format(
                pred, e)
            raise error.InvalidInput(message)
        return False
    return True
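# A minimal sketch (hypothetical, not part of this module) of a `pred` that
# passes the checks above: it is callable, takes arguments named 'params' and
# 'x', and is written in jax.numpy, so grad() and jit(grad()) both succeed.
import jax.numpy as jnp
from jax import grad, jit

def linear_pred(params, x):
    return jnp.dot(params['w'], x) + params['b']  # pure JAX ops only

example_params = {'w': jnp.ones(3), 'b': jnp.array(0.)}
example_x = jnp.arange(3.)
# JAX differentiates through the dict-of-arrays pytree without complaint
g = jit(grad(lambda p, x_: jnp.sum(linear_pred(p, x_))))(example_params, example_x)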
def __init__(self, pred=None, loss=mse, hyperparameters={}):
    self.initialized = False
    self.hps = {'G': 1, 'c': 1, 'D': 1, 'exp_con': 0.5}
    self.hps.update(hyperparameters)
    for key, value in self.hps.items():
        if hasattr(self, key):
            raise error.InvalidInput(
                "key {} is already an attribute in {}".format(key, self))
        setattr(self, key, value)  # store all hyperparameters
    self.A, self.Ainv = None, None
    self.pred, self.loss = pred, loss
    self.numpyify = lambda m: onp.array(m).astype(
        onp.double)  # maps jax.numpy to regular numpy
    if not hasattr(self, 'eta'):
        if 4 * self.G * self.D > self.exp_con:
            self.eta = 2 * self.G * self.D
        else:
            self.eta = 0.5 * self.exp_con
    if not hasattr(self, 'eps'):
        self.eps = 1.0 / ((self.eta**2) * (self.D**2))
    if self._is_valid_pred(pred, raise_error=False) and self._is_valid_loss(
            loss, raise_error=False):
        self.set_predict(pred, loss=loss)

    @jit  # partial update step for every matrix in method weights list
    def partial_update(A, Ainv, grad, w):
        A = A + np.outer(grad, grad)
        inv_grad = Ainv @ grad
        # Sherman-Morrison rank-one update of the inverse, avoiding a full inversion
        Ainv = Ainv - np.outer(inv_grad, inv_grad) / (1 + grad.T @ Ainv @ grad)
        new_grad = np.reshape(Ainv @ grad, w.shape)
        return A, Ainv, new_grad

    self.partial_update = partial_update
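# A standalone numeric check (illustrative sketch, not part of the class) that
# the Sherman-Morrison step in partial_update matches directly inverting the
# rank-one-updated matrix A + g g^T:
import jax.numpy as jnp

A_demo = jnp.eye(3) * 2.0
Ainv_demo = jnp.linalg.inv(A_demo)
g_demo = jnp.array([1.0, -0.5, 0.25])

A_new = A_demo + jnp.outer(g_demo, g_demo)
inv_g = Ainv_demo @ g_demo
Ainv_new = Ainv_demo - jnp.outer(inv_g, inv_g) / (1 + g_demo @ Ainv_demo @ g_demo)

assert jnp.allclose(Ainv_new, jnp.linalg.inv(A_new), atol=1e-5)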
def _store_optimizer(self, optimizer, pred):
    if isinstance(optimizer, Optimizer):
        optimizer.set_predict(pred)
        self.optimizer = optimizer
        return
    # guard with isclass: issubclass raises TypeError on non-class inputs
    if inspect.isclass(optimizer) and issubclass(optimizer, Optimizer):
        self.optimizer = optimizer(pred=pred)
        return
    raise error.InvalidInput("Optimizer input cannot be stored")
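# Usage sketch (the names `method` and `OGD` are assumed for illustration):
# _store_optimizer accepts either an Optimizer instance, which is re-bound to
# the given prediction function, or an Optimizer subclass, which is
# instantiated fresh with it.
#
#     method._store_optimizer(OGD(learning_rate=0.1), pred=method.predict)  # instance
#     method._store_optimizer(OGD, pred=method.predict)                     # class
#
# Anything else (e.g. a string, or an unrelated class) raises InvalidInput.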
def __init__(self, pred=None, loss=mse, hyperparameters={}):
    self.initialized = False
    self.hps = {'T': 10000, 'D': 1, 'G': 1, 'c': 1}
    self.hps.update(hyperparameters)
    for key, value in self.hps.items():
        if hasattr(self, key):
            raise error.InvalidInput(
                "key {} is already an attribute in {}".format(key, self))
        setattr(self, key, value)  # store all hyperparameters
    self.pred = pred
    self.loss = loss
    if self._is_valid_pred(pred, raise_error=False) and self._is_valid_loss(
            loss, raise_error=False):
        self.set_predict(pred, loss=loss)
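# Construction sketch (class name `OGD` assumed for illustration): defaults
# are merged with user overrides, then exposed as attributes.
#
#     opt = OGD(hyperparameters={'T': 500})
#     opt.T  # -> 500 (overridden)
#     opt.G  # -> 1   (default)
#
# A key that collides with an existing attribute raises InvalidInput.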
def _is_valid_loss(self, loss, raise_error=True):
    """
    Description: checks that loss is a valid function to differentiate with respect to using jax
    """
    if not callable(loss):
        if raise_error:
            raise error.InvalidInput(
                "Optimizer 'loss' input {} is not callable".format(loss))
        return False
    inputs = list(inspect.signature(loss).parameters)
    if len(inputs) != 2:
        if raise_error:
            raise error.InvalidInput(
                "Optimizer 'loss' input {} must take two arguments as input"
                .format(loss))
        return False
    try:
        jit_grad_loss = jit(grad(loss))
    except Exception as e:
        if raise_error:
            message = "JAX jit-grad failed on 'loss' input {}. Full error message: \n{}".format(
                loss, e)
            raise error.InvalidInput(message)
        return False
    return True
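# A minimal sketch of a loss that passes these checks (hypothetical, mirroring
# the default `mse`): exactly two arguments, scalar output, implemented in
# jax.numpy so that jit(grad(loss)) succeeds.
import jax.numpy as jnp
from jax import grad, jit

def mse_example(y_pred, y_true):
    return jnp.mean(jnp.square(y_pred - y_true))  # scalar output, JAX ops only

# gradient is taken with respect to the first argument, y_pred
g_loss = jit(grad(mse_example))(jnp.array([1.0, 2.0]), jnp.array([0.0, 2.0]))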
def __init__(self, pred=None, loss=mse, learning_rate=1.0, hyperparameters={}):
    self.initialized = False
    self.lr = learning_rate
    self.hyperparameters = {'max_norm': True, 'reg': 0.0}
    self.hyperparameters.update(hyperparameters)
    for key, value in self.hyperparameters.items():
        if hasattr(self, key):
            raise error.InvalidInput(
                "key {} is already an attribute in {}".format(key, self))
        setattr(self, key, value)  # store all hyperparameters
    self.G = None
    self.pred = pred
    self.loss = loss
    if self._is_valid_pred(pred, raise_error=False) and self._is_valid_loss(
            loss, raise_error=False):
        self.set_predict(pred, loss=loss)

    @jit
    def _update(params, grad, G, max_norm):
        # accumulate squared gradients per parameter
        new_G = {
            k: (g + np.square(dw))
            for (k, g), dw in zip(G.items(), grad.values())
        }
        # track the largest gradient norm seen so far (only when max_norm is enabled)
        max_norm = np.where(
            max_norm,
            np.maximum(
                max_norm,
                np.linalg.norm([np.linalg.norm(dw) for dw in grad.values()])),
            max_norm)
        lr = self.lr / np.where(max_norm, max_norm, 1.)
        # per-coordinate step scaled by 1/sqrt of the accumulator
        new_params = {
            k: (w - lr * dw / np.sqrt(g))
            for (k, w), dw, g in zip(params.items(), grad.values(),
                                     new_G.values())
        }
        return new_params, new_G, max_norm

    self._update = _update
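# A self-contained sketch of the diagonal-AdaGrad step that `_update`
# implements (all names here are illustrative): accumulate squared gradients,
# then scale each coordinate's step by 1/sqrt of its accumulator.
import jax.numpy as jnp

w_demo = jnp.array([1.0, -2.0])
G_acc = jnp.zeros(2)
lr_demo = 1.0
for _ in range(3):
    dw = 2 * w_demo                    # gradient of ||w||^2
    G_acc = G_acc + jnp.square(dw)     # per-coordinate accumulator
    w_demo = w_demo - lr_demo * dw / jnp.sqrt(G_acc)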
def __init__(self, pred=None, loss=mse, learning_rate=1.0, hyperparameters={}):
    self.initialized = False
    self.lr = learning_rate
    self.hps = {
        'reg': 0.00,
        'eps': 0.0001,
        'max_norm': 0,
        'project': False,
        'full_matrix': False
    }
    self.hps.update(hyperparameters)
    self.original_max_norm = self.hps['max_norm']
    for key, value in self.hps.items():
        if hasattr(self, key):
            raise error.InvalidInput(
                "key {} is already an attribute in {}".format(key, self))
        setattr(self, key, value)  # store all hyperparameters
    self.A, self.Ainv = None, None
    self.pred, self.loss = pred, loss
    self.numpyify = lambda m: onp.array(m).astype(
        onp.double)  # maps jax.numpy to regular numpy
    if self._is_valid_pred(pred, raise_error=False) and self._is_valid_loss(
            loss, raise_error=False):
        self.set_predict(pred, loss=loss)

    @jit  # partial update step for every matrix in method weights list
    def partial_update(A, Ainv, grad, w):
        A = A + np.outer(grad, grad)
        inv_grad = Ainv @ grad
        Ainv = Ainv - np.outer(inv_grad, inv_grad) / (1 + grad.T @ Ainv @ grad)
        new_grad = np.reshape(Ainv @ grad, w.shape)
        return A, Ainv, new_grad

    self.partial_update = partial_update
def __init__(self, pred=None, loss=mse, learning_rate=1.0, hyperparameters={}):
    self.initialized = False
    self.lr = learning_rate
    self.hyperparameters = {'reg': 0.00}  # L2 regularization, default value 0.0
    self.hyperparameters.update(hyperparameters)
    for key, value in self.hyperparameters.items():
        if hasattr(self, key):
            raise error.InvalidInput(
                "key {} is already an attribute in {}".format(key, self))
        setattr(self, key, value)  # store all hyperparameters
    self.pred = pred
    self.loss = loss
    if self._is_valid_pred(pred, raise_error=False) and self._is_valid_loss(
            loss, raise_error=False):
        self.set_predict(pred, loss=loss)
def __init__(self, pred=None, loss=mse, learning_rate=1.0, hyperparameters={}):
    self.initialized = False
    self.lr = learning_rate
    self.hyperparameters = {'T': 0, 'max_norm': True}
    self.hyperparameters.update(hyperparameters)
    self.original_max_norm = self.hyperparameters['max_norm']
    for key, value in self.hyperparameters.items():
        if hasattr(self, key):
            raise error.InvalidInput(
                "key {} is already an attribute in {}".format(key, self))
        setattr(self, key, value)  # store all hyperparameters
    self.G = None
    self.pred = pred
    self.loss = loss
    if self._is_valid_pred(pred, raise_error=False) and self._is_valid_loss(
            loss, raise_error=False):
        self.set_predict(pred, loss=loss)
def __init__(self, pred=None, loss=mse, learning_rate=1.0, hyperparameters={}):
    self.initialized = False
    self.lr = learning_rate
    self.hyperparameters = {
        'reg': 0.0,
        'beta_1': 0.9,
        'beta_2': 0.999,
        'eps': 1e-7,
        'max_norm': True
    }
    self.hyperparameters.update(hyperparameters)
    for key, value in self.hyperparameters.items():
        if hasattr(self, key):
            raise error.InvalidInput(
                "key {} is already an attribute in {}".format(key, self))
        setattr(self, key, value)  # store all hyperparameters
    self.beta_1_t, self.beta_2_t = self.beta_1, self.beta_2
    self.m, self.v = None, None
    self.pred = pred
    self.loss = loss
    if self._is_valid_pred(pred, raise_error=False) and self._is_valid_loss(
            loss, raise_error=False):
        self.set_predict(pred, loss=loss)

    @jit  # helper update method
    def _update(params, grad, m, v, max_norm, beta_1_t, beta_2_t):
        new_m = {
            k: self.beta_1 * m_i + (1. - self.beta_1) * dw
            for (k, m_i), dw in zip(m.items(), grad.values())
        }
        new_v = {
            k: self.beta_2 * v_i + (1. - self.beta_2) * np.square(dw)
            for (k, v_i), dw in zip(v.items(), grad.values())
        }
        # bias-corrected estimates (iterate over values, not dict keys)
        m_t = [m_i / (1 - beta_1_t) for m_i in new_m.values()]
        v_t = [v_i / (1 - beta_2_t) for v_i in new_v.values()]
        # maintain current power of betas
        beta_1_t, beta_2_t = beta_1_t * self.beta_1, beta_2_t * self.beta_2
        max_norm = np.where(
            max_norm,
            np.maximum(
                max_norm,
                np.linalg.norm([np.linalg.norm(dw) for dw in grad.values()])),
            max_norm)
        lr = self.lr / np.where(max_norm, max_norm, 1.)
        new_params = {
            k: (w - lr * m_i / (np.sqrt(v_i) + self.eps))
            for (k, w), v_i, m_i in zip(params.items(), v_t, m_t)
        }
        return new_params, new_m, new_v, max_norm, beta_1_t, beta_2_t

    self._update = _update
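# Illustrative check (standalone, not the class method) of Adam's bias
# correction: with m_0 = 0, the raw first-moment estimate is biased toward
# zero by a factor (1 - beta_1^t), which dividing by beta_1_t undoes.
import jax.numpy as jnp

beta_1, dw = 0.9, jnp.array(1.0)    # constant gradient for clarity
m, beta_1_t = jnp.array(0.0), beta_1
for _ in range(5):
    m = beta_1 * m + (1. - beta_1) * dw
    m_hat = m / (1 - beta_1_t)      # bias-corrected estimate, as in _update
    beta_1_t = beta_1_t * beta_1    # maintain current power of beta_1
# with a constant gradient, the corrected estimate recovers dw exactly
assert jnp.allclose(m_hat, dw)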