Example #1
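The examples in this listing are snippets: Examples #1-#10 are test methods from BrainPy's autograd test suite, Example #11 appears to be a fragment of a recurrent-network training script, and Example #12 is a method of a fixed-point finder class. None of them carries its own imports; a plausible preamble, assuming BrainPy's usual import conventions rather than the original files' exact headers, is:

import inspect
import time
from functools import partial

import numpy as np
import pytest

import brainpy as bp
import brainpy.math as bm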
  def test_grad_ob_argnums(self):
    class Test(bp.Base):
      def __init__(self):
        super(Test, self).__init__()

        self.a = bm.TrainVar(bm.ones(10))
        self.b = bm.TrainVar(bm.random.randn(10))
        self.c = bm.TrainVar(bm.random.uniform(size=10))

      def __call__(self, d):
        return bm.sum(self.a + self.b + self.c + 2 * d)

    bm.random.seed(0)

    t = Test()
    f_grad = bm.grad(t, t.vars(), argnums=0)
    var_grads, arg_grads = f_grad(bm.random.random(10))
    for g in var_grads.values(): assert (g == 1.).all()
    assert (arg_grads == 2.).all()

    t = Test()
    f_grad = bm.grad(t, t.vars(), argnums=[0])
    var_grads, arg_grads = f_grad(bm.random.random(10))
    for g in var_grads.values(): assert (g == 1.).all()
    assert (arg_grads[0] == 2.).all()

    t = Test()
    f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=0)
    arg_grads = f_grad(bm.random.random(10))
    assert (arg_grads == 2.).all()

    t = Test()
    f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=[0])
    arg_grads = f_grad(bm.random.random(10))
    assert (arg_grads[0] == 2.).all()
Example #2
  def test_grad_ob1(self):
    class Test(bp.Base):
      def __init__(self):
        super(Test, self).__init__()

        self.a = bm.TrainVar(bm.ones(10))
        self.b = bm.TrainVar(bm.random.randn(10))
        self.c = bm.TrainVar(bm.random.uniform(size=10))

      def __call__(self):
        return bm.sum(self.a + self.b + self.c)

    bm.random.seed(0)

    t = Test()
    f_grad = bm.grad(t, grad_vars=t.vars())
    grads = f_grad()
    for g in grads.values(): assert (g == 1.).all()

    t = Test()
    f_grad = bm.grad(t, grad_vars=[t.a, t.b], dyn_vars=t.vars())
    grads = f_grad()
    for g in grads: assert (g == 1.).all()

    t = Test()
    f_grad = bm.grad(t, grad_vars=t.a, dyn_vars=t.vars())
    grads = f_grad()
    assert (grads == 1.).all()
Example #3
  def test_grad_ob_aux_return(self):
    class Test(bp.Base):
      def __init__(self):
        super(Test, self).__init__()
        self.a = bm.TrainVar(bm.ones(10))
        self.b = bm.TrainVar(bm.random.randn(10))
        self.c = bm.TrainVar(bm.random.uniform(size=10))

      def __call__(self):
        return bm.sum(self.a + self.b + self.c), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(0)
    t = Test()
    f_grad = bm.grad(t, grad_vars=[t.a, t.b], dyn_vars=t.vars(),
                     has_aux=True, return_value=True)
    grads, returns, aux = f_grad()
    for g in grads: assert (g == 1.).all()
    assert returns == bm.sum(t.a + t.b + t.c)
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)

    t = Test()
    f_grad = bm.grad(t, grad_vars=t.a, dyn_vars=t.vars(),
                     has_aux=True, return_value=True)
    grads, returns, aux = f_grad()
    assert (grads == 1.).all()
    assert returns == bm.sum(t.a + t.b + t.c)
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
Example #4
  def test_grad_ob_argnums_aux_return(self):
    class Test(bp.Base):
      def __init__(self):
        super(Test, self).__init__()
        self.a = bm.TrainVar(bm.ones(10))
        self.b = bm.TrainVar(bm.random.randn(10))
        self.c = bm.TrainVar(bm.random.uniform(size=10))

      def __call__(self, d):
        return bm.sum(self.a + self.b + self.c + 2 * d), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(0)

    t = Test()
    f_grad = bm.grad(t, grad_vars=t.vars(), argnums=0, has_aux=True, return_value=True)
    d = bm.random.random(10)
    (var_grads, arg_grads), loss, aux = f_grad(d)
    for g in var_grads.values(): assert (g == 1.).all()
    assert (arg_grads == 2.).all()
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
    assert loss == t(d)[0]

    t = Test()
    f_grad = bm.grad(t, grad_vars=t.vars(), argnums=[0], has_aux=True, return_value=True)
    d = bm.random.random(10)
    (var_grads, arg_grads), loss, aux = f_grad(d)
    for g in var_grads.values(): assert (g == 1.).all()
    assert (arg_grads[0] == 2.).all()
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
    assert loss == t(d)[0]

    t = Test()
    f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=0, has_aux=True, return_value=True)
    d = bm.random.random(10)
    arg_grads, loss, aux = f_grad(d)
    assert (arg_grads == 2.).all()
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
    assert loss == t(d)[0]

    t = Test()
    f_grad = bm.grad(t, dyn_vars=t.vars(), argnums=[0], has_aux=True, return_value=True)
    d = bm.random.random(10)
    arg_grads, loss, aux = f_grad(d)
    assert (arg_grads[0] == 2.).all()
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
    assert loss == t(d)[0]
Example #5
  def test_grad_pure_func_aux1(self):
    def call(a, b, c):
      return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(1)
    f_grad = bm.grad(call, argnums=[0, 1, 2])
    with pytest.raises(TypeError):
      f_grad(bm.ones(10), bm.random.randn(10), bm.random.uniform(size=10))
Example #6
  def test_grad_pure_func_2(self):
    def call(a, b, c): return bm.sum(a + b + c)

    bm.random.seed(1)
    a = bm.ones(10)
    b = bm.random.randn(10)
    c = bm.random.uniform(size=10)
    f_grad = bm.grad(call)
    assert (f_grad(a, b, c) == 1.).all()
Example #7
  def test_grad_pure_func_aux2(self):
    def call(a, b, c):
      return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(1)
    f_grad = bm.grad(call, argnums=[0, 1, 2], has_aux=True)
    grads, aux = f_grad(bm.ones(10), bm.random.randn(10), bm.random.uniform(size=10))
    for g in grads: assert (g == 1.).all()
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
Example #8
  def test_grad_pure_func_return1(self):
    def call(a, b, c): return bm.sum(a + b + c)

    bm.random.seed(1)
    a = bm.ones(10)
    b = bm.random.randn(10)
    c = bm.random.uniform(size=10)
    f_grad = bm.grad(call, return_value=True)
    grads, returns = f_grad(a, b, c)
    assert (grads == 1.).all()
    assert returns == bm.sum(a + b + c)
Example #9
  def test_grad_pure_func_1(self):
    def call(a, b, c): return bm.sum(a + b + c)

    bm.random.seed(1)
    a = bm.ones(10)
    b = bm.random.randn(10)
    c = bm.random.uniform(size=10)
    f_grad = bm.grad(call, argnums=[0, 1, 2])
    grads = f_grad(a, b, c)

    for g in grads: assert (g == 1.).all()
Example #10
  def test_grad_func_return_aux1(self):
    def call(a, b, c):
      return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(1)
    a = bm.ones(10)
    b = bm.random.randn(10)
    c = bm.random.uniform(size=10)
    f_grad = bm.grad(call, return_value=True, has_aux=True)
    grads, returns, aux = f_grad(a, b, c)
    assert (grads == 1.).all()
    assert returns == bm.sum(a + b + c)
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
Example #11
          num_hidden=hidden_size,
          num_output=output_size,
          num_batch=batch_size,
          dt=env.dt)

# %%
# prediction method
predict = bm.jit(net.predict, dyn_vars=net.vars())

# Adam optimizer
opt = bm.optimizers.Adam(lr=0.001, train_vars=net.train_vars().unique())

# gradient function
grad_f = bm.grad(net.loss,
                 dyn_vars=net.vars(),
                 grad_vars=net.train_vars().unique(),
                 return_value=True,
                 has_aux=True)


# training function
@bm.jit
@bm.function(nodes=(net, opt))
def train(xs, ys):
    grads, (loss, os) = grad_f(xs, ys)
    opt.update(grads)
    return loss, os


# %%
running_acc = 0
Example #12
    def find_fps_with_gd_method(self,
                                candidates,
                                tolerance=1e-5,
                                num_batch=100,
                                num_opt=10000,
                                opt_setting=None):
        """Optimize fixed points with gradient descent methods.

    Parameters
    ----------
    candidates : jax.ndarray, JaxArray
      The array with the shape of (batch size, state dim) of hidden states
      of RNN to start training for fixed points.
    tolerance: float
      The loss threshold during optimization
    num_opt : int
      The maximum number of optimization.
    num_batch : int
      Print training information during optimization every so often.
    opt_setting: optional, dict
      The optimization settings.
    """

        # optimization settings
        if opt_setting is None:
            opt_method = bm.optimizers.Adam
            opt_lr = bm.optimizers.ExponentialDecay(0.2, 1, 0.9999)
            opt_setting = {
                'beta1': 0.9,
                'beta2': 0.999,
                'eps': 1e-8,
                'name': None
            }
        else:
            assert isinstance(opt_setting, dict)
            assert 'method' in opt_setting
            assert 'lr' in opt_setting
            opt_method = opt_setting.pop('method')
            if isinstance(opt_method, str):
                assert opt_method in bm.optimizers.__all__
                opt_method = getattr(bm.optimizers, opt_method)
            assert isinstance(opt_method, type)
            if bm.optimizers.Optimizer not in inspect.getmro(opt_method):
                raise ValueError(f'"method" must be a subclass of '
                                 f'bm.optimizers.Optimizer, but got {opt_method}.')
            opt_lr = opt_setting.pop('lr')
            assert isinstance(opt_lr, (int, float, bm.optimizers.Scheduler))

        if self.verbose:
            print(
                f"Optimizing with {opt_method.__name__} to find fixed points:")

        # set up optimization
        fixed_points = bm.Variable(bm.asarray(candidates))
        grad_f = bm.grad(lambda: self.f_loss_batch(fixed_points.value).mean(),
                         grad_vars={'a': fixed_points},
                         return_value=True)
        opt = opt_method(train_vars={'a': fixed_points},
                         lr=opt_lr,
                         **opt_setting)
        dyn_vars = opt.vars() + {'_a': fixed_points}

        def train(idx):
            gradients, loss = grad_f()
            opt.update(gradients)
            return loss

        @partial(bm.jit,
                 dyn_vars=dyn_vars,
                 static_argnames=('start_i', 'num_batch'))
        def batch_train(start_i, num_batch):
            f = bm.make_loop(train, dyn_vars=dyn_vars, has_return=True)
            return f(bm.arange(start_i, start_i + num_batch))

        # Run the optimization
        opt_losses = []
        do_stop = False
        num_opt_loops = num_opt // num_batch
        for oidx in range(num_opt_loops):
            if do_stop: break
            batch_idx_start = oidx * num_batch
            start_time = time.time()
            (_, losses) = batch_train(start_i=batch_idx_start,
                                      num_batch=num_batch)
            batch_time = time.time() - start_time
            opt_losses.append(losses)

            if self.verbose:
                print(
                    f"    "
                    f"Batches {batch_idx_start + 1}-{batch_idx_start + num_batch} "
                    f"in {batch_time:0.2f} sec, Training loss {losses[-1]:0.10f}"
                )

            if losses[-1] < tolerance:
                do_stop = True
                if self.verbose:
                    print(
                        f'    '
                        f'Stop optimization as mean training loss {losses[-1]:0.10f} '
                        f'is below tolerance {tolerance:0.10f}.')
        self.opt_losses = bm.concatenate(opt_losses)
        self._losses = np.asarray(self.f_loss_batch(fixed_points))
        self._fixed_points = np.asarray(fixed_points)
        self._selected_ids = np.arange(fixed_points.shape[0])
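A minimal usage sketch for find_fps_with_gd_method, assuming a finder instance named `finder` that exposes the method above; the instance name, the candidate shapes, and the concrete hyper-parameter values are illustrative, not taken from the original code. Only the method signature and the required 'method'/'lr' keys of `opt_setting` come from the code itself.

import brainpy.math as bm

# 1024 candidate hidden states of dimension 100 (both numbers are placeholders).
candidates = bm.random.uniform(-1., 1., size=(1024, 100))

# Default optimizer: Adam with an exponentially decaying learning rate.
finder.find_fps_with_gd_method(candidates, tolerance=1e-5, num_batch=100, num_opt=10000)

# Custom optimizer: 'method' is a name in bm.optimizers.__all__ or an Optimizer
# subclass, 'lr' is a number or a Scheduler, and the remaining keys are passed
# on to the optimizer's constructor.
finder.find_fps_with_gd_method(
    candidates,
    tolerance=1e-6,
    num_batch=200,
    num_opt=20000,
    opt_setting=dict(method='Adam',
                     lr=bm.optimizers.ExponentialDecay(0.1, 1, 0.9999),
                     beta1=0.9, beta2=0.999, eps=1e-8, name=None))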