Example #1
File: minmax_test.py  Project: swyl/jax
    def testTracedStepSize(self):
        def loss(x, _):
            return np.dot(x, x)

        x0 = np.ones(2)
        num_iters = 100
        step_size = 0.1

        init_fun, _ = minmax.sgd(step_size)
        opt_state = init_fun(x0)

        @jit
        def update(opt_state, step_size):
            _, update_fun = minmax.sgd(step_size)
            x = minmax.get_params(opt_state)
            g = grad(loss)(x, None)
            return update_fun(0, g, opt_state)

        update(opt_state, 0.9)  # doesn't crash
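
The test snippets above omit the file-level imports. A minimal sketch of what they assume, based on the early JAX layout in which jax.experimental.minmax was the precursor of jax.experimental.optimizers:

# Assumed imports for the snippets above (module paths follow the early JAX
# layout; the minmax module was later renamed jax.experimental.optimizers).
import jax.numpy as np
from jax import grad, jit
import jax.experimental.minmax as minmax

The point of the test is that step_size reaches update as a traced value, so minmax.sgd is rebuilt inside the jit-compiled function instead of being closed over as a constant.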
Example #2
File: minmax_test.py  Project: swyl/jax
def update(opt_state, step_size):
    _, update_fun = minmax.sgd(step_size)
    x = minmax.get_params(opt_state)
    g = grad(loss)(x, None)
    return update_fun(0, g, opt_state)
Example #3
print("Trained loss1: {:0.2f}".format(loss(weights_final1, train_data)))

################
# Full-batch SGD using the JAX optimizers library
# import jax.experimental.optimizers  # newer name for the same module
import jax.experimental.minmax as optimizers


@jit
def step(i, opt_state):
    # One full-batch SGD step: read the current weights, compute the
    # gradient of the training loss, and apply the optimizer update.
    weights = optimizers.get_params(opt_state)
    g = train_gradient_fun(weights)
    return opt_update(i, g, opt_state)


opt_init, opt_update = optimizers.sgd(step_size=lr)
opt_state = opt_init(weights_init)
print("jax SGD")
for i in range(max_iter):
    opt_state = step(i, opt_state)
weights_final2 = optimizers.get_params(opt_state)
print("Trained loss2: {:0.2f}".format(loss(weights_final2, train_data)))

################
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
from scipy.optimize import minimize
import itertools

lbfgs_memory = 10
opts = {"maxiter": max_iter, "maxcor": lbfgs_memory}  # maxcor = L-BFGS history size
iter = itertools.count()  # iteration counter (note: shadows the builtin iter)
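
The snippet stops before the actual call to minimize. A minimal sketch of how the pieces above are typically wired together, assuming weights_init is a flat parameter array and reusing loss, train_data, and train_gradient_fun from earlier; scipy_objective, scipy_gradient, lbfgs_callback, and weights_final3 are names introduced here for illustration:

import numpy as onp  # plain NumPy for SciPy interop


def scipy_objective(w):
    # SciPy expects a Python float back from the objective.
    return float(loss(w, train_data))


def scipy_gradient(w):
    # SciPy expects a float64 NumPy array for the gradient.
    return onp.asarray(train_gradient_fun(w), dtype=onp.float64)


def lbfgs_callback(w):
    print("L-BFGS iteration {}".format(next(iter)))


result = minimize(
    scipy_objective,
    onp.asarray(weights_init, dtype=onp.float64),
    jac=scipy_gradient,
    method="L-BFGS-B",
    options=opts,
    callback=lbfgs_callback,
)
weights_final3 = result.x
print("Trained loss3: {:0.2f}".format(loss(weights_final3, train_data)))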