def test_grad_ob_argnums(self):
  """Gradients of an object w.r.t. positional args, with/without variable grads."""

  class Test(bp.Base):
    def __init__(self):
      super().__init__()
      self.a = bm.TrainVar(bm.ones(10))
      self.b = bm.TrainVar(bm.random.randn(10))
      self.c = bm.TrainVar(bm.random.uniform(size=10))

    def __call__(self, d):
      return bm.sum(self.a + self.b + self.c + 2 * d)

  bm.random.seed(0)

  # variable grads together with a scalar argnums
  model = Test()
  grad_fn = bm.grad(model, grad_vars=model.vars(), argnums=0)
  var_grads, arg_grads = grad_fn(bm.random.random(10))
  assert all((g == 1.).all() for g in var_grads.values())
  assert (arg_grads == 2.).all()

  # variable grads together with a list argnums
  model = Test()
  grad_fn = bm.grad(model, grad_vars=model.vars(), argnums=[0])
  var_grads, arg_grads = grad_fn(bm.random.random(10))
  assert all((g == 1.).all() for g in var_grads.values())
  assert (arg_grads[0] == 2.).all()

  # argument grads only, scalar argnums
  model = Test()
  grad_fn = bm.grad(model, dyn_vars=model.vars(), argnums=0)
  arg_grads = grad_fn(bm.random.random(10))
  assert (arg_grads == 2.).all()

  # argument grads only, list argnums
  model = Test()
  grad_fn = bm.grad(model, dyn_vars=model.vars(), argnums=[0])
  arg_grads = grad_fn(bm.random.random(10))
  assert (arg_grads[0] == 2.).all()
def test_grad_ob1(self):
  """Gradients of an object w.r.t. its trainable variables (dict / list / single)."""

  class Test(bp.Base):
    def __init__(self):
      super().__init__()
      self.a = bm.TrainVar(bm.ones(10))
      self.b = bm.TrainVar(bm.random.randn(10))
      self.c = bm.TrainVar(bm.random.uniform(size=10))

    def __call__(self):
      return bm.sum(self.a + self.b + self.c)

  bm.random.seed(0)

  # all variables at once: grads come back as a dict
  model = Test()
  grad_fn = bm.grad(model, grad_vars=model.vars())
  grads = grad_fn()
  assert all((g == 1.).all() for g in grads.values())

  # a subset given as a list: grads come back as a sequence
  model = Test()
  grad_fn = bm.grad(model, grad_vars=[model.a, model.b], dyn_vars=model.vars())
  grads = grad_fn()
  assert all((g == 1.).all() for g in grads)

  # a single variable: grads come back as a bare array
  model = Test()
  grad_fn = bm.grad(model, grad_vars=model.a, dyn_vars=model.vars())
  grads = grad_fn()
  assert (grads == 1.).all()
def test_grad_ob_aux_return(self):
  """Object gradients with both ``has_aux=True`` and ``return_value=True``."""

  class Test(bp.Base):
    def __init__(self):
      super().__init__()
      self.a = bm.TrainVar(bm.ones(10))
      self.b = bm.TrainVar(bm.random.randn(10))
      self.c = bm.TrainVar(bm.random.uniform(size=10))

    def __call__(self):
      # (loss, aux) pair: aux is ignored for differentiation
      return bm.sum(self.a + self.b + self.c), (bm.sin(100), bm.exp(0.1))

  bm.random.seed(0)

  # list of grad_vars
  model = Test()
  grad_fn = bm.grad(model, grad_vars=[model.a, model.b],
                    dyn_vars=model.vars(), has_aux=True, return_value=True)
  grads, returns, aux = grad_fn()
  assert all((g == 1.).all() for g in grads)
  assert returns == bm.sum(model.a + model.b + model.c)
  assert aux[0] == bm.sin(100)
  assert aux[1] == bm.exp(0.1)

  # single grad_var
  model = Test()
  grad_fn = bm.grad(model, grad_vars=model.a,
                    dyn_vars=model.vars(), has_aux=True, return_value=True)
  grads, returns, aux = grad_fn()
  assert (grads == 1.).all()
  assert returns == bm.sum(model.a + model.b + model.c)
  assert aux[0] == bm.sin(100)
  assert aux[1] == bm.exp(0.1)
def test_grad_ob_argnums_aux_return(self):
  """Object gradients combining argnums, ``has_aux`` and ``return_value``."""

  class Test(bp.Base):
    def __init__(self):
      super().__init__()
      self.a = bm.TrainVar(bm.ones(10))
      self.b = bm.TrainVar(bm.random.randn(10))
      self.c = bm.TrainVar(bm.random.uniform(size=10))

    def __call__(self, d):
      return bm.sum(self.a + self.b + self.c + 2 * d), (bm.sin(100), bm.exp(0.1))

  bm.random.seed(0)

  # variable grads + scalar argnums
  model = Test()
  grad_fn = bm.grad(model, grad_vars=model.vars(), argnums=0,
                    has_aux=True, return_value=True)
  d = bm.random.random(10)
  (var_grads, arg_grads), loss, aux = grad_fn(d)
  assert all((g == 1.).all() for g in var_grads.values())
  assert (arg_grads == 2.).all()
  assert aux[0] == bm.sin(100)
  assert aux[1] == bm.exp(0.1)
  assert loss == model(d)[0]

  # variable grads + list argnums
  model = Test()
  grad_fn = bm.grad(model, grad_vars=model.vars(), argnums=[0],
                    has_aux=True, return_value=True)
  d = bm.random.random(10)
  (var_grads, arg_grads), loss, aux = grad_fn(d)
  assert all((g == 1.).all() for g in var_grads.values())
  assert (arg_grads[0] == 2.).all()
  assert aux[0] == bm.sin(100)
  assert aux[1] == bm.exp(0.1)
  assert loss == model(d)[0]

  # argument grads only, scalar argnums
  model = Test()
  grad_fn = bm.grad(model, dyn_vars=model.vars(), argnums=0,
                    has_aux=True, return_value=True)
  d = bm.random.random(10)
  arg_grads, loss, aux = grad_fn(d)
  assert (arg_grads == 2.).all()
  assert aux[0] == bm.sin(100)
  assert aux[1] == bm.exp(0.1)
  assert loss == model(d)[0]

  # argument grads only, list argnums
  model = Test()
  grad_fn = bm.grad(model, dyn_vars=model.vars(), argnums=[0],
                    has_aux=True, return_value=True)
  d = bm.random.random(10)
  arg_grads, loss, aux = grad_fn(d)
  assert (arg_grads[0] == 2.).all()
  assert aux[0] == bm.sin(100)
  assert aux[1] == bm.exp(0.1)
  assert loss == model(d)[0]
def test_grad_pure_func_aux1(self):
  """An aux-returning function must be rejected when ``has_aux`` is not set."""

  def call(a, b, c):
    return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1))

  bm.random.seed(1)
  grad_fn = bm.grad(call, argnums=[0, 1, 2])
  # without has_aux=True the tuple-valued output is not a scalar loss
  with pytest.raises(TypeError):
    grad_fn(bm.ones(10), bm.random.randn(10), bm.random.uniform(size=10))
def test_grad_pure_func_2(self):
  """Default argnums: gradient w.r.t. the first argument only."""

  def call(a, b, c):
    return bm.sum(a + b + c)

  bm.random.seed(1)
  x, y, z = bm.ones(10), bm.random.randn(10), bm.random.uniform(size=10)
  grad_fn = bm.grad(call)
  assert (grad_fn(x, y, z) == 1.).all()
def test_grad_pure_func_aux2(self):
  """With ``has_aux=True`` the aux data is returned alongside the grads."""

  def call(a, b, c):
    return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1))

  bm.random.seed(1)
  grad_fn = bm.grad(call, argnums=[0, 1, 2], has_aux=True)
  grads, aux = grad_fn(bm.ones(10), bm.random.randn(10),
                       bm.random.uniform(size=10))
  assert all((g == 1.).all() for g in grads)
  assert aux[0] == bm.sin(100)
  assert aux[1] == bm.exp(0.1)
def test_grad_pure_func_return1(self):
  """With ``return_value=True`` the loss value is returned with the grads."""

  def call(a, b, c):
    return bm.sum(a + b + c)

  bm.random.seed(1)
  x, y, z = bm.ones(10), bm.random.randn(10), bm.random.uniform(size=10)
  grad_fn = bm.grad(call, return_value=True)
  grads, returns = grad_fn(x, y, z)
  assert (grads == 1.).all()
  assert returns == bm.sum(x + y + z)
def test_grad_pure_func_1(self):
  """Gradients of a pure function w.r.t. all three positional arguments."""

  def call(a, b, c):
    return bm.sum(a + b + c)

  bm.random.seed(1)
  x, y, z = bm.ones(10), bm.random.randn(10), bm.random.uniform(size=10)
  grad_fn = bm.grad(call, argnums=[0, 1, 2])
  grads = grad_fn(x, y, z)
  assert all((g == 1.).all() for g in grads)
def test_grad_func_return_aux1(self):
  """Combined ``return_value=True`` and ``has_aux=True`` on a pure function."""

  def call(a, b, c):
    return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1))

  bm.random.seed(1)
  x, y, z = bm.ones(10), bm.random.randn(10), bm.random.uniform(size=10)
  grad_fn = bm.grad(call, return_value=True, has_aux=True)
  grads, returns, aux = grad_fn(x, y, z)
  assert (grads == 1.).all()
  assert returns == bm.sum(x + y + z)
  assert aux[0] == bm.sin(100)
  assert aux[1] == bm.exp(0.1)
num_hidden=hidden_size, num_output=output_size, num_batch=batch_size, dt=env.dt) # %% # prediction method predict = bm.jit(net.predict, dyn_vars=net.vars()) # Adam optimizer opt = bm.optimizers.Adam(lr=0.001, train_vars=net.train_vars().unique()) # gradient function grad_f = bm.grad(net.loss, dyn_vars=net.vars(), grad_vars=net.train_vars().unique(), return_value=True, has_aux=True) # training function @bm.jit @bm.function(nodes=(net, opt)) def train(xs, ys): grads, (loss, os) = grad_f(xs, ys) opt.update(grads) return loss, os # %% running_acc = 0
def find_fps_with_gd_method(self, candidates, tolerance=1e-5, num_batch=100,
                            num_opt=10000, opt_setting=None):
  """Optimize fixed points with gradient descent methods.

  Parameters
  ----------
  candidates : jax.ndarray, JaxArray
    The array with the shape of (batch size, state dim) of hidden states
    of RNN to start training for fixed points.
  tolerance : float
    The loss threshold during optimization.
  num_opt : int
    The maximum number of optimization.
  num_batch : int
    Print training information during optimization every so often.
  opt_setting : optional, dict
    The optimization settings.
  """
  # resolve the optimizer class, learning rate and extra keyword settings
  if opt_setting is None:
    opt_method = bm.optimizers.Adam
    opt_lr = bm.optimizers.ExponentialDecay(0.2, 1, 0.9999)
    opt_setting = {'beta1': 0.9, 'beta2': 0.999, 'eps': 1e-8, 'name': None}
  else:
    assert isinstance(opt_setting, dict)
    assert 'method' in opt_setting
    assert 'lr' in opt_setting
    opt_method = opt_setting.pop('method')
    if isinstance(opt_method, str):
      assert opt_method in bm.optimizers.__all__
      opt_method = getattr(bm.optimizers, opt_method)
    assert isinstance(opt_method, type)
    # the given class must subclass the Optimizer base
    if bm.optimizers.Optimizer not in inspect.getmro(opt_method):
      raise ValueError
    opt_lr = opt_setting.pop('lr')
    assert isinstance(opt_lr, (int, float, bm.optimizers.Scheduler))

  if self.verbose:
    print(f"Optimizing with {opt_method.__name__} to find fixed points:")

  # set up the optimization: the candidate states become a trainable variable
  fixed_points = bm.Variable(bm.asarray(candidates))
  grad_f = bm.grad(lambda: self.f_loss_batch(fixed_points.value).mean(),
                   grad_vars={'a': fixed_points},
                   return_value=True)
  opt = opt_method(train_vars={'a': fixed_points}, lr=opt_lr, **opt_setting)
  dyn_vars = opt.vars() + {'_a': fixed_points}

  def train(idx):
    # one gradient-descent step; the loop index is unused
    gradients, loss = grad_f()
    opt.update(gradients)
    return loss

  @partial(bm.jit, dyn_vars=dyn_vars, static_argnames=('start_i', 'num_batch'))
  def batch_train(start_i, num_batch):
    f = bm.make_loop(train, dyn_vars=dyn_vars, has_return=True)
    return f(bm.arange(start_i, start_i + num_batch))

  # run the optimization in chunks of ``num_batch`` steps, reporting progress
  opt_losses = []
  do_stop = False
  num_opt_loops = int(num_opt / num_batch)
  for oidx in range(num_opt_loops):
    if do_stop:
      break
    batch_idx_start = oidx * num_batch
    start_time = time.time()
    (_, losses) = batch_train(start_i=batch_idx_start, num_batch=num_batch)
    batch_time = time.time() - start_time
    opt_losses.append(losses)
    if self.verbose:
      print(f" "
            f"Batches {batch_idx_start + 1}-{batch_idx_start + num_batch} "
            f"in {batch_time:0.2f} sec, Training loss {losses[-1]:0.10f}")
    # stop early once the latest loss is under the tolerance
    if losses[-1] < tolerance:
      do_stop = True
      if self.verbose:
        print(f' '
              f'Stop optimization as mean training loss {losses[-1]:0.10f} '
              f'is below tolerance {tolerance:0.10f}.')

  self.opt_losses = bm.concatenate(opt_losses)
  self._losses = np.asarray(self.f_loss_batch(fixed_points))
  self._fixed_points = np.asarray(fixed_points)
  self._selected_ids = np.arange(fixed_points.shape[0])