def testTracedStepSize(self):
    def loss(x, _): return np.dot(x, x)
    x0 = np.ones(2)
    num_iters = 100
    step_size = 0.1

    init_fun, _ = minmax.sgd(step_size)
    opt_state = init_fun(x0)

    @jit
    def update(opt_state, step_size):
        _, update_fun = minmax.sgd(step_size)
        x = minmax.get_params(opt_state)
        g = grad(loss)(x, None)
        return update_fun(0, g, opt_state)

    update(opt_state, 0.9)  # doesn't crash
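# The test above re-creates the optimizer inside the jitted `update`, so the
# step size is a traced argument rather than a compile-time constant. A minimal
# usage sketch (hypothetical values, not part of the original test), assuming
# `update`, `opt_state`, and `minmax` are as defined above: the same jitted
# function can be reused with a step size that changes on every call.
for i in range(10):                               # illustrative iteration count
    opt_state = update(opt_state, 0.9 / (1 + i))  # traced step size varies per call
x_final = minmax.get_params(opt_state)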
print("Trained loss1: {:0.2f}".format(loss(weights_final1, train_data))) ################ # Full batch SGD using optim library #import jax.experimental.optimizers import jax.experimental.minmax as optimizers @jit def step(i, opt_state): weights = optimizers.get_params(opt_state) g = train_gradient_fun(weights) return opt_update(i, g, opt_state) opt_init, opt_update = optimizers.sgd(step_size=lr) opt_state = opt_init(weights_init) print("jax SGD") for i in range(max_iter): opt_state = step(i, opt_state) weights_final2 = optimizers.get_params(opt_state) print("Trained loss2: {:0.2f}".format(loss(weights_final2, train_data))) ################ # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html from scipy.optimize import minimize lbfgs_memory = 10 opts = {"maxiter": max_iter, "maxcor": lbfgs_memory} iter = itertools.count()