示例#1
0
def madgrad(step_size=0.01, momentum=0.9, epsilon=1.0e-6):
    """
    Implementation of MADGRAD (Defazio and Jelassi, arXiv:2101.11075).

    Args:
        step_size: positive scalar, or a callable schedule mapping the
            iteration index to a positive scalar.
        momentum: scalar or callable schedule in [0, 1). The iterate is
            averaged as ``x <- momentum * x + (1 - momentum) * z``, i.e. the
            paper's ``c_k`` equals ``1 - momentum``. With ``momentum=0`` the
            iterate is set directly to the dual-averaged point ``z``.
        epsilon: small constant added to the cube-root denominator.

    Returns:
        An ``Optimizer(init, update, get_params)`` triple. The state is
        ``(x, s, nu, x0)``: current iterate, weighted gradient sum, weighted
        squared-gradient sum and the initial point.
    """
    step_size = make_schedule(step_size)
    momentum = make_schedule(momentum)

    def init(x):
        s = jnp.zeros_like(x)
        nu = jnp.zeros_like(x)
        x0 = x
        return x, s, nu, x0

    def update(i, g, state):
        x, s, nu, x0 = state
        # lambda_k = gamma_k * sqrt(k + 1)
        lbda = step_size(i) * jnp.sqrt(i + 1)
        s = s + lbda * g
        nu = nu + lbda * g * g
        # Dual-averaging point with cube-root adaptive scaling.
        z = x0 - s / (jnp.power(nu, 1.0 / 3.0) + epsilon)
        # Momentum averaging: keep `momentum` of the old iterate and move
        # the remaining (1 - momentum) fraction towards z.
        beta = momentum(i)
        x = beta * x + (1 - beta) * z
        return x, s, nu, x0

    def get_params(state):
        x, s, nu, x0 = state
        return x

    return Optimizer(init, update, get_params)
示例#2
0
文件: cga.py 项目: niklasschmitz/fax
def full_solve_cga(step_size_f, step_size_g, f, g):
    """CGA using a naive implementation that builds the full mixed Hessians.

    Each update forms the mixed second-derivative matrices ``Dxyf`` and
    ``Dyxg`` with ``make_mixed_hessian`` and solves two dense linear systems
    for the coupled steps ``delta_x`` and ``delta_y``. ``x`` and ``y`` are
    indexed by ``shape[0]`` to build identity matrices, so they are assumed
    to be 1-D vectors (NOTE(review): confirm).

    Args:
        step_size_f: scalar or schedule giving the step size for ``x``.
        step_size_g: scalar or schedule giving the step size for ``y``.
        f: objective of the ``x`` player; extra ``update`` args are bound in
            front of ``(x, y)`` via ``functools.partial``.
        g: objective of the ``y`` player, bound the same way.

    Returns:
        An ``(init, update, get_params)`` triple. ``get_params`` returns the
        first two fields of the state, ``(x, y)``.
    """
    eta_f_at = optimizers.make_schedule(step_size_f)
    eta_g_at = optimizers.make_schedule(step_size_g)

    def init(inputs):
        x_init, y_init = inputs[0], inputs[1]
        return CGAState(
            x=x_init,
            y=y_init,
            delta_x=np.zeros_like(x_init),
            delta_y=np.zeros_like(y_init),
        )

    def update(i, grads, inputs, *args, **kwargs):
        # Accept a bare (x, y) pair or a 4-field state; previous deltas are
        # not needed for a full solve.
        if len(inputs) < 4:
            x, y = inputs
        else:
            x, y, _, _ = inputs

        grad_xf, grad_yg = grads
        eta_f = eta_f_at(i)
        eta_g = eta_g_at(i)

        f_bound = partial(f, *args, **kwargs)
        Dxyf = make_mixed_hessian(f_bound, 0, 1)(x, y)
        g_bound = partial(g, *args, **kwargs)
        Dyxg = make_mixed_hessian(g_bound, 1, 0)(x, y)

        rhs_x = grad_xf + eta_f * np.dot(Dxyf, grad_yg)
        lhs_x = np.eye(x.shape[0]) - eta_f**2 * np.dot(Dxyf, Dyxg)
        step_x = np.linalg.solve(lhs_x, rhs_x)

        rhs_y = grad_yg + eta_g * np.dot(Dyxg, grad_xf)
        lhs_y = np.eye(y.shape[0]) - eta_g**2 * np.dot(Dyxg, Dxyf)
        step_y = np.linalg.solve(lhs_y, rhs_y)

        return CGAState(x + eta_f * step_x, y + eta_g * step_y, step_x, step_y)

    def get_params(state):
        return state[:2]

    return init, update, get_params
示例#3
0
def ngd_cg(step_size, b1=0.9, b2=0.999, eps=1e-8, lmda=0.001, decay=0.9,
           fvp_fn=None, solve_fn=None):
    """Construct optimizer triple for natural gradient descent solved with CG.

    Each update calls ``solve_fn(fvp_fn, g)`` to obtain a direction ``ng``
    (presumably the natural gradient ``F^-1 g`` computed by conjugate
    gradients with a Fisher-vector product — confirm against ``cg_solve``).
    The update is then ``x <- x - alpha * ng``, where
    ``alpha = sqrt(|step_size(i) / (g . ng)|)``.

    Args:
        step_size: positive scalar, or a callable representing a step size
            schedule that maps the iteration index to positive scalar.
        b1, b2, eps, lmda, decay: accepted for interface compatibility;
            not used by this implementation.
        fvp_fn: optional matrix-vector product passed to the solver. Defaults
            to the module-level ``Fvp_fn``, looked up at update time.
        solve_fn: optional callable ``solve_fn(fvp_fn, g) -> ng``. Defaults to
            the module-level ``cg_solve``, looked up at update time.

    Returns:
        An (init_fun, update_fun, get_params) triple. The state is a 1-tuple
        ``(x,)``.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        return x0,

    def update(i, g, state):
        x, = state
        fvp = Fvp_fn if fvp_fn is None else fvp_fn
        solve = cg_solve if solve_fn is None else solve_fn
        ng = solve(fvp, g)

        # Compute step size based on stats; 1e-20 guards against g . ng == 0.
        lr = step_size(i)
        alpha = np.sqrt(np.abs(lr / (np.dot(g, ng) + 1e-20)))

        # Update params and return the state tuple expected by get_params.
        x = x - alpha * ng
        return x,

    def get_params(state):
        x, = state
        return x

    return init, update, get_params
示例#4
0
def adam_custom(step_size, b1=0.9, b2=0.999, eps=1e-8):
    """Construct an Adam-style optimizer triple whose state carries the step.

    Unlike standard Adam, ``update`` does not apply the step to any
    parameters. It writes the bias-corrected Adam step
    ``step_size(i) * mhat / (sqrt(vhat) + eps)`` into the first state slot,
    and ``get_params`` returns that slot. Right after ``init`` the slot still
    holds ``x0``. Presumably the caller applies the returned step itself —
    NOTE(review): confirm with callers.

    Args:
        step_size: positive scalar, or a callable schedule mapping the
            iteration index to a positive scalar.
        b1: decay rate of the first-moment estimate (default 0.9).
        b2: decay rate of the second-moment estimate (default 0.999).
        eps: numerical-stability constant added to the denominator.

    Returns:
        An (init_fun, update_fun, get_params) triple.
    """
    schedule = make_schedule(step_size)

    def init(x0):
        first_moment = np.zeros_like(x0)
        second_moment = np.zeros_like(x0)
        return x0, first_moment, second_moment

    def update(i, g, state):
        _, m, v = state
        t = i + 1
        m = (1 - b1) * g + b1 * m
        v = (1 - b2) * np.square(g) + b2 * v
        m_hat = m / (1 - b1 ** t)
        v_hat = v / (1 - b2 ** t)
        step = schedule(i) * m_hat / (np.sqrt(v_hat) + eps)
        return step, m, v

    def get_params(state):
        step, _, _ = state
        return step

    return init, update, get_params
示例#5
0
def sgd(step_size):
    """Plain gradient-descent optimizer triple.

    ``init`` returns a deep copy of ``x0``, so the optimizer state never
    aliases the caller's object. ``update`` returns
    ``x - step_size(i) * g``, and ``get_params`` is the identity.
    """
    lr_at = jax_opt.make_schedule(step_size)

    def init(x0):
        return copy.deepcopy(x0)

    def update(i, g, x):
        lr = lr_at(i)
        return x - lr * g

    def get_params(x):
        return x

    return init, update, get_params
示例#6
0
def momentum(learning_rate, momentum=0.9):
    """A standard momentum optimizer for testing (two-function variant).

    Returns an ``(init_fun, update_fun)`` pair with no ``get_params``.
    ``update_fun(i, g, x, velocity)`` takes the parameters and the velocity
    as separate arguments and returns the new ``(x, velocity)``.

    NOTE(review): another ``momentum`` definition immediately follows this
    one in the file and rebinds the name, so this variant is not reachable
    by name at module level — confirm whether it can be removed.
    """
    schedule = opt.make_schedule(learning_rate)

    def init_fun(x0):
        return x0, np.zeros_like(x0)

    def update_fun(i, g, x, velocity):
        new_velocity = momentum * velocity + g
        new_x = x - schedule(i) * new_velocity
        return new_x, new_velocity

    return init_fun, update_fun
def momentum(learning_rate, momentum=0.9):
    """A standard (heavy-ball) momentum optimizer for testing.

    Update rule: ``v <- momentum * v + g`` followed by
    ``x <- x - learning_rate(i) * v``. The original author describes
    ``jax.experimental.optimizers.momentum`` as Nesterov and different from
    this one — NOTE(review): confirm against the JAX version in use.

    Returns:
        An ``(init_fun, update_fun, get_params)`` triple; the state is
        ``(x, velocity)``.
    """
    schedule = opt.make_schedule(learning_rate)

    def init_fun(x0):
        return x0, np.zeros_like(x0)

    def update_fun(i, g, state):
        params, vel = state
        vel = momentum * vel + g
        params = params - schedule(i) * vel
        return params, vel

    def get_params(state):
        params, _ = state
        return params

    return init_fun, update_fun, get_params
示例#8
0
def momentum(step_size, mass, weight_decay=0.):
    """Heavy-ball momentum optimizer triple with optional L2 weight decay.

    When ``weight_decay`` is non-zero, ``weight_decay * x`` is added to the
    gradient before the velocity update ``v <- mass * v + g``. The
    parameters then move by ``-step_size(i) * v``.

    Returns:
        An ``(init, update, get_params)`` triple; the state is
        ``(x, velocity)``.
    """
    lr_at = jax_opt.make_schedule(step_size)

    def init(x0):
        return x0, np.zeros_like(x0)

    def update(i, g, state):
        params, vel = state
        grad = g if weight_decay == 0. else g + weight_decay * params
        vel = mass * vel + grad
        return params - lr_at(i) * vel, vel

    def get_params(state):
        params, _ = state
        return params

    return init, update, get_params
示例#9
0
def adamW(step_size, b1=0.9, b2=0.999, eps=1e-8, w=0.01):
    """Construct optimizer triple for AdamW (Adam with weight decay).

    This docstring is different from the rest because we want to submit this
    to the jax library, so DON'T CHANGE IT TO SPHINX-STYLE!

    The weight-decay term ``w * x`` is added to the Adam direction (not to
    the gradient) and scaled by the step size:
    ``x <- x - step_size(i) * (mhat / (sqrt(vhat) + eps) + w * x)``.

    Args:
        step_size: positive scalar, or a callable representing a step size schedule
            that maps the iteration index to positive scalar.
        b1: optional, a positive scalar value for beta_1, the exponential decay rate
            for the first moment estimates (default 0.9).
        b2: optional, a positive scalar value for beta_2, the exponential decay rate
            for the second moment estimates (default 0.999).
        eps: optional, a positive scalar value for epsilon, a small constant for
            numerical stability (default 1e-8).
        w: optional, weight decay term (default 0.01)

    Returns:
        An (init_fun, update_fun, get_params) triple.
    """
    lr_at = make_schedule(step_size)

    def init(x0):
        return x0, np.zeros_like(x0), np.zeros_like(x0)

    def update(i, g, state):
        params, m, v = state
        t = i + 1
        m = (1 - b1) * g + b1 * m
        v = (1 - b2) * (g ** 2) + b2 * v
        m_hat = m / (1 - b1 ** t)
        v_hat = v / (1 - b2 ** t)
        direction = m_hat / (np.sqrt(v_hat) + eps) + w * params
        params = params - lr_at(i) * direction
        return params, m, v

    def get_params(state):
        params, _, _ = state
        return params

    return init, update, get_params
示例#10
0
def adahessian(step_size=1e-1,
               b1=0.9,
               b2=0.999,
               eps=1e-8,
               weight_decay=0.0,
               hessian_power=1):
    """Construct optimizer triple for AdaHessian.

    Like Adam, except that the second-moment estimate is built from a
    Hessian-based quantity ``h`` instead of the squared gradient. ``h`` is
    first passed through ``average_magnitude``, a helper defined elsewhere;
    NOTE(review): presumably spatial averaging of the Hessian diagonal —
    confirm.

    Args:
        step_size: positive scalar, or a callable representing a step size
            schedule that maps the iteration index to positive scalar.
        b1: optional, a positive scalar value for beta_1, the exponential decay
            rate for the first moment estimates (default 0.9).
        b2: optional, a positive scalar value for beta_2, the exponential decay
            rate for the second moment estimates (default 0.999).
        eps: optional, a positive scalar value for epsilon, a small constant
            for numerical stability (default 1e-8).
        weight_decay: optional, weight decay (L2 penalty) added to the update
            direction as ``weight_decay * x`` (default 0).
        hessian_power: optional, exponent applied to ``sqrt(vhat)`` in the
            denominator (default 1).

    Returns:
        An (init_fun, update_fun, get_params) triple. Note that ``update_fun``
        has the signature ``update(i, g, h, state)``, taking the Hessian-based
        term ``h`` in addition to the gradient ``g``.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        m0 = jnp.zeros_like(x0)
        v0 = jnp.zeros_like(x0)
        return x0, m0, v0

    def update(i, g, h, state):
        x, m, v = state
        h = average_magnitude(h)
        m = (1 - b1) * g + b1 * m  # First moment estimate.
        v = (1 - b2) * jnp.square(
            h) + b2 * v  # Second moment estimate for the Hessian.
        mhat = m / (1 - b1**(i + 1))  # Bias correction.
        vhat = v / (1 - b2**(i + 1))
        x = x - step_size(i) * (mhat / (jnp.sqrt(vhat)**hessian_power + eps) +
                                weight_decay * x)
        return x, m, v

    def get_params(state):
        x, _, _ = state
        return x

    return init, update, get_params
示例#11
0
def rmomentum(step_size, manifold, mass):
    """Construct optimizer triple for Riemannian momentum descent.

    Each step:
      1. maps the Euclidean gradient to the tangent space with
         ``manifold.egrad_to_rgrad``;
      2. updates the velocity as the exponential moving average
         ``mass * velocity + (1 - mass) * rgrad``;
      3. calls ``manifold.retraction_transport(x, velocity,
         -step_size(i) * velocity)``. This returns the new point and the new
         velocity (presumably transported to the tangent space at the new
         point — confirm against the manifold implementation).

    Args:
        step_size:
            positive scalar, or a callable representing a step size schedule
            that maps the iteration index to positive scalar.
        manifold:
            the manifold to perform riemannian optimization on; must provide
            ``egrad_to_rgrad`` and ``retraction_transport``.
        mass:
            positive scalar in the moving-average velocity update; it weights
            the old velocity by ``mass`` and the new gradient by ``1 - mass``.

    Returns:
        An (init_fun, update_fun, get_params) triple. The state is
        ``(x, velocity)`` with zero initial velocity.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        return x0, jax.numpy.zeros_like(x0)

    def update(i, grad, state):
        """Take one Riemannian momentum step from ``state = (x, velocity)``."""
        x, velocity = state
        rgrad = manifold.egrad_to_rgrad(x, grad)
        # Both velocity and rgrad are in the tangent space at x.
        velocity = mass * velocity + (1 - mass) * rgrad
        new_x, velocity = manifold.retraction_transport(
            x, velocity, -step_size(i) * velocity)
        return new_x, velocity

    def get_params(state):
        x, _ = state
        return x

    return init, update, get_params
示例#12
0
def gradient_ascent(step_size):
    """Construct optimizer triple for (stochastic) gradient ascent.

    Each update moves the parameters along the gradient,
    ``x <- x + step_size(i) * g``, i.e. it increases the objective.

    Args:
      step_size: positive scalar, or a callable representing a step size schedule
        that maps the iteration index to positive scalar.

    Returns:
      An (init_fun, update_fun, get_params) triple. The state is the
      parameters themselves.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        return x0

    def update(i, g, x):
        # Ascent: note the "+" (plain SGD would subtract).
        return x + step_size(i) * g

    def get_params(x):
        return x

    return init, update, get_params
示例#13
0
def rsgd(step_size, manifold):
    """Construct optimizer triple for Riemannian stochastic gradient descent.

    Each step converts the Euclidean gradient with
    ``manifold.egrad_to_rgrad`` and moves from ``x`` along
    ``-step_size(i) * rgrad`` using ``manifold.retraction``.

    Args:
        step_size: positive scalar, or a callable representing a step size
            schedule that maps the iteration index to positive scalar.
        manifold: object providing ``egrad_to_rgrad(x, grad)`` and
            ``retraction(x, v)``.

    Returns:
        An (init_fun, update_fun, get_params) triple. The state is the point
        on the manifold itself.
    """
    schedule = make_schedule(step_size)

    def init(x0):
        return x0

    def update(i, grad, x):
        riemannian_grad = manifold.egrad_to_rgrad(x, grad)
        tangent_step = -schedule(i) * riemannian_grad
        return manifold.retraction(x, tangent_step)

    def get_params(x):
        return x

    return init, update, get_params
示例#14
0
def madam(step_size=0.01, b2=0.999, g_bound=10):
    """Construct an optimizer triple for the multiplicative Madam update.

    Each step normalises the gradient by a bias-corrected running RMS,
    clips it to ``[-g_bound, g_bound]`` and rescales every parameter by
    ``exp(-step_size(i) * g_norm * sign(x))``. Parameters are then clipped
    to ``[-s, s]``, where ``s`` is the RMS of the initial parameters.

    Args:
        step_size: positive scalar, or a schedule mapping the iteration index
            to a positive scalar.
        b2: decay rate of the second-moment estimate.
        g_bound: symmetric bound applied to the normalised gradient.

    Returns:
        An ``(init, update, get_params)`` triple; the state is
        ``(x, s, v)``.
    """
    step_size = optimizers.make_schedule(step_size)

    def init(x0):
        s0 = np.sqrt(np.mean(x0 * x0))  # Initial scale (RMS of x0).
        v0 = np.zeros_like(x0)  # 2nd moment.
        return x0, s0, v0

    def update(i, g, state):
        x, s, v = state
        v = (1 - b2) * np.square(g) + b2 * v  # Update 2nd moment.
        vhat = v / (1 - b2**(i + 1))  # Bias correction.
        # nan_to_num turns the 0/0 (nan) and x/0 (inf) cases into finite values.
        g_norm = np.nan_to_num(g / np.sqrt(vhat))  # Normalise gradient.
        g_norm = np.clip(g_norm, -g_bound, g_bound)  # Bound g.
        # Multiplicative update. Rebind rather than using ``x *= ...``: with
        # NumPy arrays the in-place form would overwrite the array held in the
        # caller's state (which, after ``init``, is the caller's ``x0``).
        x = x * np.exp(-step_size(i) * g_norm * np.sign(x))
        x = np.clip(x, -s, s)  # Bound parameters.
        return x, s, v

    def get_params(state):
        x, s, v = state
        return x

    return init, update, get_params
示例#15
0
def main(unused_argv):
    """Train a Wide ResNet on CIFAR-10 with L2 regularisation and log metrics.

    All configuration is read from the module-level ``FLAGS``. The function:
      - builds a sharded data pipeline and the model;
      - uses a cross-entropy or MSE loss plus ``L2 * ||params||^2``, and a
        momentum optimizer;
      - runs ``training_steps`` pmapped updates;
      - every ``FLAGS.meas_step`` steps, evaluates train/test loss and
        accuracy and writes them to a CSV under ``wrnlogs/`` (or
        ``wrnlogs/<jobdir>``).

    With ``FLAGS.checkpointing`` set, weights are saved via
    ``utils.save_weights`` every ``FLAGS.checkpointing`` steps.

    With ``FLAGS.L2_sch`` ("AutoL2") set, the L2 coefficient is divided by
    ``FLAGS.L2dec`` once past the refractory period
    (``FLAGS.delay / currL2`` steps after the last decay). The decay fires
    when neither of the last two measurements improved on the best training
    loss, or neither improved on the best training accuracy.

    On TPU, a marker file ``done/<TPU_ADDR>`` containing the CSV path is
    written at the end.

    Args:
        unused_argv: command-line arguments; ignored.
    """
    # ``jax.api`` is deprecated (and removed in later JAX releases); the same
    # functions are exported from the top-level ``jax`` package.
    from jax import grad, jit, vmap, pmap, device_put

    # The following is required to use TPU Driver as JAX's backend.
    if FLAGS.TPU:
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + os.environ[
            'TPU_ADDR'] + ':8470'
        TPU_ADDR = os.environ['TPU_ADDR']
    ndevices = xla_bridge.device_count()
    if not FLAGS.TPU:
        ndevices = 1

    pmap = partial(pmap, axis_name='i')

    # Setup some experiment parameters.
    meas_step = FLAGS.meas_step
    training_epochs = int(FLAGS.epochs)

    # In "physical" mode the epoch budget is divided by lr (optionally by
    # L2 * lr).
    tmult = 1.0
    if FLAGS.physical:
        tmult = FLAGS.lr
        if FLAGS.physicalL2:
            tmult = FLAGS.L2 * tmult
    if FLAGS.physical:
        training_epochs = 1 + int(FLAGS.epochs / tmult)

    print('Evolving for {:}e'.format(training_epochs))
    losst = FLAGS.losst
    learning_rate = FLAGS.lr
    batch_size_per_device = FLAGS.bs
    N = FLAGS.N
    K = FLAGS.K

    batch_size = batch_size_per_device * ndevices
    # 50000 = number of CIFAR-10 training images.
    steps_per_epoch = 50000 // batch_size
    training_steps = training_epochs * steps_per_epoch

    # Filename from FLAGS.
    filename = 'wrnL2_' + losst + '_n' + str(N) + '_k' + str(K)
    if FLAGS.momentum:
        filename += '_mom'
    if FLAGS.L2_sch:
        filename += '_L2sch' + '_decay' + str(FLAGS.L2dec) + '_del' + str(
            FLAGS.delay)
    if FLAGS.seed != 1:
        filename += 'seed' + str(FLAGS.seed)
    filename += '_L2' + str(FLAGS.L2)
    if FLAGS.std_wrn_sch:
        filename += '_stddec'
        if FLAGS.physical:
            filename += 'phys'
    else:
        filename += '_ctlr'
    if not FLAGS.augment:
        filename += '_noaug'
    if not FLAGS.mix:
        filename += '_nomixup'
    filename += '_bs' + str(batch_size) + '_lr' + str(learning_rate)
    if FLAGS.jobdir is not None:
        filedir = os.path.join('wrnlogs', FLAGS.jobdir)
    else:
        filedir = 'wrnlogs'
    os.makedirs(filedir, exist_ok=True)
    filedir = os.path.join(filedir, filename + '.csv')

    print('Saving log to ', filename)
    print('Found {} cores.'.format(ndevices))

    # Load CIFAR10 data and create a minimal pipeline.
    train_images, train_labels, test_images, test_labels = utils.load_data(
        'cifar10')
    train_images = np.reshape(train_images, (-1, 32, 32 * 3))
    train = (train_images, train_labels)
    test = (test_images, test_labels)
    k = train_labels.shape[-1]
    train = utils.shard_data(train, ndevices)
    test = utils.shard_data(test, ndevices)

    # Create a Wide Resnet and replicate its parameters across the devices.
    initparams, f, _ = utils.WideResnetnt(N, K, k)

    # Loss and optimizer definitions.
    l2_norm = lambda params: tree_map(lambda x: np.sum(x**2), params)
    l2_reg = lambda params: tree_reduce(lambda x, y: x + y, l2_norm(params))
    currL2 = FLAGS.L2
    L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))

    def xentr(params, images_and_labels):
        images, labels = images_and_labels
        return -np.mean(stax.logsoftmax(f(params, images)) * labels)

    def mse(params, data_tuple):
        """MSE loss."""
        x, y = data_tuple
        return 0.5 * np.mean((y - f(params, x))**2)

    if losst == 'xentr':
        print('Using xentr')
        lossm = xentr
    else:
        print('Using mse')
        lossm = mse

    loss = lambda params, data, L2: lossm(params, data) + L2 * l2_reg(params)

    def accuracy(params, images_and_labels):
        images, labels = images_and_labels
        return np.mean(
            np.array(np.argmax(f(params, images), axis=1) == np.argmax(labels,
                                                                       axis=1),
                     dtype=np.float32))

    # Define optimizer.
    if FLAGS.std_wrn_sch:
        lr = learning_rate
        first_epoch = int(60 / 200 * training_epochs)
        learning_rate_fn = optimizers.piecewise_constant(
            np.array([1, 2, 3]) * first_epoch * steps_per_epoch,
            np.array([lr, lr * 0.2, lr * 0.2**2, lr * 0.2**3]))
    else:
        learning_rate_fn = optimizers.make_schedule(learning_rate)

    if FLAGS.momentum:
        momentum = 0.9
    else:
        momentum = 0

    @pmap
    def update_step(step, state, batch_state, L2):
        batch, batch_state = batch_fn(batch_state)
        params = get_params(state)
        dparams = grad_loss(params, batch, L2)
        # Average gradients across devices.
        dparams = tree_map(lambda x: lax.psum(x, 'i') / ndevices, dparams)
        return step + 1, apply_fn(step, dparams, state), batch_state

    @pmap
    def evaluate(state, data, L2):
        params = get_params(state)
        lossmm = lossm(params, data)
        l2mm = l2_reg(params)
        return lossmm + L2 * l2mm, accuracy(params, data), lossmm, l2mm

    # Initialization and loading.
    _, params = initparams(random.PRNGKey(0), (-1, 32, 32, 3))
    replicate_array = lambda x: \
        np.broadcast_to(x, (ndevices,) + x.shape)
    replicated_params = tree_map(replicate_array, params)

    grad_loss = jit(grad(loss))
    init_fn, apply_fn, get_params = optimizers.momentum(
        learning_rate_fn, momentum)
    apply_fn = jit(apply_fn)
    key = random.PRNGKey(FLAGS.seed)

    batchinit_fn, batch_fn = utils.sharded_minibatcher(batch_size,
                                                       ndevices,
                                                       transform=FLAGS.augment,
                                                       k=k,
                                                       mix=FLAGS.mix)

    batch_state = pmap(batchinit_fn)(random.split(key, ndevices), train)
    state = pmap(init_fn)(replicated_params)

    if FLAGS.checkpointing:
        ## Loading of checkpoint if available/provided.
        single_state = init_fn(params)
        i0, load_state, load_params, filename0, batch_stateb = utils.load_weights(
            filename,
            single_state,
            params,
            full_file=FLAGS.load_w,
            ndevices=ndevices)
        if i0 is not None:
            filename = filename0
            if batch_stateb is not None:
                batch_state = batch_stateb
            if load_params is not None:
                state = pmap(init_fn)(load_params)
            else:
                state = load_state
        else:
            i0 = 0
    else:
        i0 = 0

    if FLAGS.steps_from_load:
        training_steps = i0 + training_steps

    batch_xs, _ = pmap(batch_fn)(batch_state)

    train_loss = []
    train_accuracy = []
    lrL = []
    test_loss = []
    test_accuracy = []
    test_L2, test_lm, train_lm, train_L2 = [], [], [], []
    L2_t = []
    idel0 = i0
    start = time.time()
    # Best training accuracy / lowest training bare loss seen so far, used by
    # the AutoL2 schedule. Explicit initial values replace a bare ``except:``
    # that detected their first use.
    maxacc, minloss = -float('inf'), float('inf')

    step = pmap(lambda x: x)(i0 * np.ones((ndevices, )))

    # Start training loop.
    if FLAGS.checkpointing:
        print('Evolving for {:}e and saving every {:}s'.format(
            training_epochs, FLAGS.checkpointing))

    print(
        'Epoch\tLearning Rate\tTrain bareLoss\t L2_norm \tTest Loss\tTrain Error\tTest Error\tTime / Epoch'
    )

    for i in range(i0, training_steps):
        if i % meas_step == 0:
            # Make Measurement
            l, a, lm, L2m = evaluate(state, test, L2p)
            test_loss += [np.mean(l)]
            test_accuracy += [np.mean(a)]
            test_lm += [np.mean(lm)]
            test_L2 += [np.mean(L2m)]
            train_batch, _ = pmap(batch_fn)(batch_state)
            l, a, lm, L2m = evaluate(state, train_batch, L2p)

            train_loss += [np.mean(l)]
            train_accuracy += [np.mean(a)]
            train_lm += [np.mean(lm)]
            train_L2 += [np.mean(L2m)]
            L2_t.append(currL2)
            lrL += [learning_rate_fn(i)]

            if FLAGS.L2_sch and i > FLAGS.delay / currL2 + idel0 and len(
                    train_lm) > 2 and ((minloss <= train_lm[-1]
                                        and minloss <= train_lm[-2]) or
                                       (maxacc >= train_accuracy[-1]
                                        and maxacc >= train_accuracy[-2])):
                # If AutoL2 is on and we are beyond the refractory period,
                # decay if the loss or error have not improved in the last
                # two measurements.
                print('Decaying L2 to', currL2 / FLAGS.L2dec)
                currL2 = currL2 / FLAGS.L2dec
                L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))
                idel0 = i

            elif FLAGS.L2_sch and len(train_lm) >= 2:
                # Update the best values seen so far.
                maxacc = max(train_accuracy[-2], maxacc)
                minloss = min(train_lm[-2], minloss)

            if i % (meas_step * 10) == 0 or i == i0:
                # Save measurements to csv
                epoch = batch_size * i / 50000
                dt = (time.time() - start) / (meas_step * 10) * steps_per_epoch
                print(('{}\t' + ('{: .4f}\t' * 7)).format(
                    epoch, learning_rate_fn(i), train_lm[-1], train_L2[-1],
                    test_loss[-1], train_accuracy[-1], test_accuracy[-1], dt))

                start = time.time()
                data = {
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'train_acc': train_accuracy,
                    'test_acc': test_accuracy
                }
                data['train_bareloss'] = train_lm
                data['train_L2'] = train_L2
                data['test_bareloss'] = test_lm
                data['test_L2'] = test_L2
                data['L2_t'] = L2_t
                df = pd.DataFrame(data)

                df['learning_rate'] = lrL
                df['width'] = K
                df['batch_size'] = batch_size
                df['step'] = i0 + onp.arange(0, len(train_loss)) * meas_step

                df.to_csv(filedir, index=False)

        if FLAGS.checkpointing:
            ### SAVE MODEL
            if i % FLAGS.checkpointing == 0 and i > i0:

                if not os.path.exists('weights/'):
                    os.makedirs('weights/')
                saveparams = tree_flatten(state[0])[0]
                if ndevices > 1:
                    saveparams = [el[0] for el in saveparams]
                saveparams = np.concatenate(
                    [el.reshape(-1) for el in saveparams])

                step0 = i
                print('Step', i)
                print('saving at', filename, step0, 'size:', saveparams.shape)

                utils.save_weights(filename, step0, saveparams, batch_state)

        ## UPDATE
        step, state, batch_state = update_step(step, state, batch_state, L2p)

    print('Training done')

    if FLAGS.TPU:
        # Completion marker: done/<TPU_ADDR> containing the CSV path. Create
        # the directory so a missing done/ does not crash a finished run.
        os.makedirs('done', exist_ok=True)
        with open('done/' + TPU_ADDR, 'w') as fp:
            fp.write(filedir)
示例#16
0
文件: base.py 项目: lindermanlab/jxf
    def proximal_optimizer(cls, prior=None, step_size=0.75, **kwargs):
        """Return an optimizer triplet, like jax.experimental.optimizers,
        to perform proximal gradient ascent on the likelihood with a penalty
        on the KL divergence between distributions from one iteration to the
        next. This boils down to taking a convex combination of sufficient
        statistics from this data and those that have been accumulated from
        past data.

        Args:
            prior: forwarded to ``cls.fit_with_stats`` by ``get_distribution``.
            step_size: scalar or schedule (via ``make_schedule``). Its value at
                iteration ``itr`` is passed as the third argument to
                ``convex_combination``.
            **kwargs: forwarded to ``cls.sufficient_statistics`` and
                ``cls.fit_with_stats``.

        Returns:

            initial_state    :: dictionary of optimizer state
                                (sufficient statistics and number of datapoints)
            update           :: itr, dataset, state -> state
            get_distribution :: state -> Distribution object

        ``update`` returns a new state dict and does not modify the ``state``
        passed to it.
        """
        initial_state = dict(suff_stats=None, num_datapoints=0.0)
        schedule = make_schedule(step_size)

        @format_dataset
        def update(itr,
                   dataset,
                   state,
                   weights=None,
                   suff_stats=None,
                   num_datapoints=0.0,
                   scale_factor=1.0):

            # Compute the sufficient statistics and the number of datapoints,
            # unless the caller supplied them.
            if suff_stats is None:
                num_datapoints = 0.0
                for data_dict, these_weights in zip(dataset, weights):
                    these_stats = cls.sufficient_statistics(
                        **data_dict, **kwargs)

                    if these_weights is not None:
                        # Weight the statistics; axis 0 indexes datapoints.
                        these_stats = tuple(
                            np.tensordot(these_weights, s, axes=(0, 0))
                            for s in these_stats)
                        num_datapoints += these_weights.sum()
                    else:
                        # Unweighted: each row along axis 0 is one datapoint.
                        # Count them from the stats (there are no weights to
                        # sum).
                        num_datapoints += these_stats[0].shape[0]
                        these_stats = tuple(s.sum(axis=0) for s in these_stats)

                    # Add to our accumulated statistics.
                    suff_stats = sum_tuples(suff_stats, these_stats)

            # Scale the sufficient statistics by the given scale factor.
            # This is as if the sufficient statistics were accumulated
            # from the entire dataset rather than a batch.
            suff_stats = tuple(scale_factor * ss for ss in suff_stats)
            num_datapoints = scale_factor * num_datapoints

            if state["suff_stats"] is None:
                return dict(suff_stats=suff_stats,
                            num_datapoints=num_datapoints)

            # Take a convex combination of sufficient statistics from this
            # batch and those accumulated thus far. Build a new dict rather
            # than mutating ``state``, so earlier states held by the caller
            # stay intact.
            rho = schedule(itr)
            new_state = dict(state)
            new_state["suff_stats"] = convex_combination(
                state["suff_stats"], suff_stats, rho)
            new_state["num_datapoints"] = convex_combination(
                state["num_datapoints"], num_datapoints, rho)
            return new_state

        def get_distribution(state):
            # Update parameters with the average stats
            return cls.fit_with_stats(state["suff_stats"],
                                      state["num_datapoints"],
                                      prior=prior,
                                      **kwargs)

        return initial_state, update, get_distribution
示例#17
0
文件: optimizers.py 项目: byzhang/d3p
def adadp(step_size=1e-3,
          tol=1.0,
          stability_check=True,
          alpha_min=0.9,
          alpha_max=1.1):
    """Construct optimizer triple for the adaptive learning rate optimizer of
    Koskela and Honkela.

    Reference:
    A. Koskela, A. Honkela: Learning Rate Adaptation for Federated and
    Differentially Private Learning (https://arxiv.org/abs/1809.03832).

    Updates come in even/odd pairs.
      - Even step: from ``x``, record a full step ``x - lr * g`` and take a
        half step ``x - lr/2 * g``.
      - Odd step: take a second half step with the new gradient and compare
        the result with the recorded full step. The relative discrepancy
        ``err_e`` rescales the learning rate by ``sqrt(tol / err_e)``,
        clipped to ``[alpha_min, alpha_max]``. If ``stability_check`` is set
        and ``err_e > tol``, the pair is rejected and the parameters revert
        to the point before the even step.

    Args:
    step_size: the initial step size
    tol: error tolerance for the discretized gradient steps
    stability_check: settings to True rejects some updates in favor of a more
        stable algorithm
    alpha_min: lower multiplicative bound of learning rate update per step
    alpha_max: upper multiplicative bound of learning rate update per step

    Returns:
        An (init_fun, update_fun, get_params) triple. The state is
        ``(x, lr, x_stepped, x_prev)``.
    """
    step_size = make_schedule(step_size)

    def init(x0):
        lr = step_size(0)
        x_stepped = tree_map(jnp.zeros_like, x0)
        return x0, lr, x_stepped, x0

    def _compute_update_step(x, g, step_size_):
        return tree_map(lambda x_, g_: x_ - step_size_ * g_, x, g)

    def _update_even_step(args):
        g, state, new_x = args
        x, lr, _, _ = state

        # Remember the starting point and where one full step would land.
        x_stepped = _compute_update_step(x, g, lr)

        return new_x, lr, x_stepped, x

    def _update_odd_step(args):
        g, state, new_x = args
        _, lr, x_stepped, x_prev = state

        norm_partials = tree_map(
            lambda x_full, x_halfs: jnp.sum(
                ((x_full - x_halfs) / jnp.maximum(1., x_full))**2), x_stepped,
            new_x)

        err_e = jnp.array(tree_leaves(norm_partials))
        # note(lumip): paper specifies the approximate error function as
        #   using absolute values, but since we square anyways, those are
        #   not required here; the resulting array is partial squared sums
        #   of the l2-norm over all gradient elements (per gradient site)

        err_e = jnp.sqrt(jnp.sum(err_e))  # summing partial gradient norm

        # Multiplicative learning-rate change, bounded by the user-supplied
        # [alpha_min, alpha_max] (previously hard-coded to [0.9, 1.1]).
        new_lr = lr * jnp.minimum(
            jnp.maximum(jnp.sqrt(tol / err_e), alpha_min), alpha_max)

        reject = stability_check and err_e > tol
        new_x = tree_map(
            lambda prev, cand: jnp.where(reject, prev, cand), x_prev, new_x)

        return new_x, new_lr, x_stepped, x_prev

    def update(i, g, state):
        x, lr = state[0], state[1]

        new_x = _compute_update_step(x, g, 0.5 * lr)
        return lax.cond(i % 2 == 0, _update_even_step, _update_odd_step,
                        (g, state, new_x))

    def get_params(state):
        x = state[0]
        return x

    return init, update, get_params
示例#18
0
def main(_):
    """Train an image classifier on ``FLAGS.dataset``, optionally with DP-SGD.

    Configuration is read from ``FLAGS.config``. Progress is logged both via
    ``logging`` and to ``<FLAGS.exp_dir>/stdout.log``; at the start of every
    epoch the unpacked optimizer state is pickled to
    ``<FLAGS.exp_dir>/ckpt_<epoch>``.

    Args:
      _: Unused positional argument (presumably argv from an ``app.run``-style
        launcher; TODO confirm).

    Raises:
      ValueError: if the configured architecture, loss or optimizer is not one
        of the supported choices.
    """
    logging.info('Starting experiment.')
    configs = FLAGS.config

    # Create model folder for outputs. A gfile error here is deliberately
    # ignored (presumably because the directory already exists).
    try:
        gfile.MakeDirs(FLAGS.exp_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.exp_dir), 'w+')

    logging.info('Loading data.')
    tic = time.time()

    train_images, train_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'train')
    n_train = len(train_images)
    train_mu, train_std = onp.mean(train_images), onp.std(train_images)
    train = data.DataChunk(X=(train_images - train_mu) / train_std,
                           Y=train_labels,
                           image_size=32,
                           image_channels=3,
                           label_dim=1,
                           label_format='numeric')

    test_images, test_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'test')
    test = data.DataChunk(
        X=(test_images - train_mu) / train_std,  # normalize w train mean/std
        Y=test_labels,
        image_size=32,
        image_channels=3,
        label_dim=1,
        label_format='numeric')

    # Data augmentation
    if configs.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    batch = data.minibatcher(train, configs.batch_size, transform=augmentation)

    # Model architecture
    if configs.architect == 'wrn':
        init_random_params, predict = wide_resnet(configs.block_size,
                                                  configs.channel_multiplier,
                                                  10)
    elif configs.architect == 'cnn':
        init_random_params, predict = cnn()
    else:
        raise ValueError('Model architecture not implemented.')

    if configs.seed is not None:
        key = random.PRNGKey(configs.seed)
    else:
        key = random.PRNGKey(int(time.time()))
    _, params = init_random_params(key, (-1, 32, 32, 3))

    def count_parameters(params):
        """Return the total number of scalar entries over all parameter leaves."""
        return tree_util.tree_reduce(
            operator.add, tree_util.tree_map(lambda x: np.prod(x.shape),
                                             params))

    n_params = count_parameters(params)
    logging.info('Number of parameters: %d', n_params)
    stdout_log.write('Number of params: {}\n'.format(n_params))

    # loss functions
    def cross_entropy_loss(params, x_img, y_lbl):
        """Negative mean of log-softmax predictions weighted by ``y_lbl``."""
        return -np.mean(stax.logsoftmax(predict(params, x_img)) * y_lbl)

    def mse_loss(params, x_img, y_lbl):
        """Half the mean squared error between ``y_lbl`` and predictions."""
        return 0.5 * np.mean((y_lbl - predict(params, x_img))**2)

    def accuracy(y_lbl_hat, y_lbl):
        """Fraction of rows whose argmax over axis 1 matches the target's."""
        target_class = np.argmax(y_lbl, axis=1)
        predicted_class = np.argmax(y_lbl_hat, axis=1)
        return np.mean(predicted_class == target_class)

    # Loss and gradient
    if configs.loss == 'xent':
        loss = cross_entropy_loss
    elif configs.loss == 'mse':
        loss = mse_loss
    else:
        raise ValueError('Loss function not implemented.')
    grad_loss = jit(grad(loss))

    # learning rate schedule and optimizer
    def cosine(initial_step_size, train_steps):
        """Quarter-cosine decay from ``initial_step_size`` to 0 at ``train_steps``."""
        k = np.pi / (2.0 * train_steps)

        def schedule(i):
            return initial_step_size * np.cos(k * i)

        return schedule

    if configs.optimization == 'sgd':
        lr_schedule = optimizers.make_schedule(configs.learning_rate)
        opt_init, opt_update, get_params = optimizers.sgd(lr_schedule)
    elif configs.optimization == 'momentum':
        lr_schedule = cosine(configs.learning_rate, configs.train_steps)
        opt_init, opt_update, get_params = optimizers.momentum(
            lr_schedule, 0.9)
    else:
        raise ValueError('Optimizer not implemented.')

    opt_state = opt_init(params)

    def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                     batch_size):
        """Return differentially private gradients of params, evaluated on batch.

        Each per-example gradient is clipped to global L2 norm
        ``l2_norm_clip``. The clipped gradients are summed, Gaussian noise
        with standard deviation ``l2_norm_clip * noise_multiplier`` is added
        to every leaf, and the result is divided by ``batch_size``.
        """
        def _clipped_grad(params, single_example_batch):
            """Evaluate gradient for a single-example batch and clip its grad norm."""
            grads = grad_loss(params, single_example_batch[0].reshape(
                (-1, 32, 32, 3)), single_example_batch[1])

            nonempty_grads, tree_def = tree_util.tree_flatten(grads)
            total_grad_norm = np.linalg.norm(
                [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
            divisor = stop_gradient(
                np.amax((total_grad_norm / l2_norm_clip, 1.)))
            normalized_nonempty_grads = [
                neg / divisor for neg in nonempty_grads
            ]
            return tree_util.tree_unflatten(tree_def,
                                            normalized_nonempty_grads)

        px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
        std_dev = l2_norm_clip * noise_multiplier
        # Sum the clipped per-example gradients over the vmapped leading axis.
        aggregated_clipped_grads = tree_util.tree_map(
            lambda n: np.sum(n, 0), px_clipped_grad_fn(batch))
        # Use an independent key per leaf. With a single shared key, equally
        # shaped leaves get *identical* noise, which an observer can cancel
        # by subtraction, voiding the privacy guarantee.
        leaves, tree_def = tree_util.tree_flatten(aggregated_clipped_grads)
        leaf_rngs = random.split(rng, len(leaves))
        noised_normalized_leaves = [
            (leaf + std_dev * random.normal(leaf_rng, leaf.shape)) /
            float(batch_size) for leaf_rng, leaf in zip(leaf_rngs, leaves)
        ]
        return tree_util.tree_unflatten(tree_def, noised_normalized_leaves)

    # summarize measurements. Clamp to at least 1 so that a batch size larger
    # than the training set does not make `s % steps_per_epoch` divide by 0.
    steps_per_epoch = max(1, n_train // configs.batch_size)

    def summarize(step, params):
        """Log train/test loss and accuracy of ``params``; return test accuracy."""
        set_entries = [train, test]
        set_bsizes = [configs.train_eval_bsize, configs.test_eval_bsize]
        set_names, loss_dict, acc_dict = ['train', 'test'], {}, {}

        for set_entry, set_bsize, set_name in zip(set_entries, set_bsizes,
                                                  set_names):
            temp_loss, temp_acc, points = 0.0, 0.0, 0
            for b in data.batch(set_entry, set_bsize):
                temp_loss += loss(params, b.X, b.Y) * b.X.shape[0]
                temp_acc += accuracy(predict(params, b.X), b.Y) * b.X.shape[0]
                points += b.X.shape[0]
            loss_dict[set_name] = temp_loss / float(points)
            acc_dict[set_name] = temp_acc / float(points)

        logging.info('Step: %s', str(step))
        logging.info('Train acc : %.4f', acc_dict['train'])
        logging.info('Train loss: %.4f', loss_dict['train'])
        logging.info('Test acc  : %.4f', acc_dict['test'])
        logging.info('Test loss : %.4f', loss_dict['test'])

        stdout_log.write('Step: {}\n'.format(step))
        stdout_log.write('Train acc : {}\n'.format(acc_dict['train']))
        stdout_log.write('Train loss: {}\n'.format(loss_dict['train']))
        stdout_log.write('Test acc  : {}\n'.format(acc_dict['test']))
        stdout_log.write('Test loss : {}\n'.format(loss_dict['test']))

        return acc_dict['test']

    toc = time.time()
    logging.info('Elapsed SETUP time: %s', str(toc - tic))
    stdout_log.write('Elapsed SETUP time: {}\n'.format(toc - tic))

    # BEGIN: training steps
    logging.info('Training network.')
    tic = time.time()
    t = time.time()

    for s in range(configs.train_steps):
        b = next(batch)
        params = get_params(opt_state)

        if FLAGS.dpsgd:
            key = random.fold_in(key, s)  # get new key for new random numbers
            # Reshape to (N, 1, 32, 32, 3) so vmap sees one single-example
            # batch per row.
            opt_state = opt_update(
                s,
                private_grad(params, (b.X.reshape(
                    (-1, 1, 32, 32, 3)), b.Y), key, configs.l2_norm_clip,
                             configs.noise_multiplier, configs.batch_size),
                opt_state)
        else:
            opt_state = opt_update(s, grad_loss(params, b.X, b.Y), opt_state)

        if s % steps_per_epoch == 0:
            epoch = s // steps_per_epoch
            # pickle emits bytes, so the file must be opened in binary write
            # mode ('wr' is not a valid mode and cannot be written to).
            with gfile.Open('{}/ckpt_{}'.format(FLAGS.exp_dir, epoch),
                            'wb') as fckpt:
                pickle.dump(optimizers.unpack_optimizer_state(opt_state),
                            fckpt)

            if FLAGS.dpsgd:
                eps = compute_epsilon(s, configs.batch_size, n_train,
                                      configs.target_delta,
                                      configs.noise_multiplier)
                stdout_log.write(
                    'For delta={:.0e}, current epsilon is: {:.2f}\n'.format(
                        configs.target_delta, eps))

            logging.info('Elapsed EPOCH time: %s', str(time.time() - t))
            stdout_log.write('Elapsed EPOCH time: {}\n'.format(time.time() - t))
            stdout_log.flush()
            t = time.time()

    toc = time.time()
    # `params` in the loop is read *before* each update, so re-read it here to
    # evaluate the result of the final optimizer step.
    params = get_params(opt_state)
    summarize(configs.train_steps, params)
    logging.info('Elapsed TRAIN time: %s', str(toc - tic))
    stdout_log.write('Elapsed TRAIN time: {}\n'.format(toc - tic))
    stdout_log.close()
示例#19
0
 def __init__(self, step_size=0.01):
     """Store ``step_size`` after normalizing it via ``experimental.make_schedule``.

     Args:
       step_size: a step size or schedule (presumably a scalar or a callable
         of the iteration index; TODO confirm against make_schedule).
     """
     schedule = experimental.make_schedule(step_size)
     self.step_size = schedule
示例#20
0
文件: cga.py 项目: niklasschmitz/fax
def cga(step_size_f,
        step_size_g,
        f,
        g,
        linear_op_solver=None,
        default_max_iter=1000,
        solve_order='alternating'):
    """Competitive gradient ascent (CGA) built on mixed Jacobian-vector products.

    Instead of materializing mixed Hessians, the cross terms are applied via
    ``make_mixed_jvp`` and the resulting linear systems are solved with
    ``linear_op_solver``.

    Args:
      step_size_f: step size or schedule for the ``x`` player (passed through
        ``optimizers.make_schedule``).
      step_size_g: step size or schedule for the ``y`` player.
      f: objective of the ``x`` player; extra ``*args``/``**kwargs`` given to
        ``update`` are bound to it with ``functools.partial``.
      g: objective of the ``y`` player, same convention as ``f``.
      linear_op_solver: callable ``(linear_op=, bvec=, init_x=)`` whose result
        exposes the solution as ``.value``. Defaults to a fixed-point
        iteration of ``x <- linear_op(x) + bvec``.
      default_max_iter: iteration cap for the default solver.
      solve_order: how the two coupled systems are solved:
        'simultaneous' solves both independently; 'xy' solves for ``delta_x``
        and derives ``delta_y`` from it; 'yx' does the reverse; 'alternating'
        uses 'xy' on odd iterations and 'yx' on even ones.

    Returns:
      An ``(init, update, get_params)`` triple.

    Raises:
      ValueError: if ``solve_order`` is not one of the supported values.
    """
    # Validate eagerly: otherwise a typo only surfaces on the first `update`
    # call as a bare KeyError from the dispatch dict below.
    valid_solve_orders = ('simultaneous', 'alternating', 'xy', 'yx')
    if solve_order not in valid_solve_orders:
        raise ValueError('solve_order must be one of {}, got {!r}'.format(
            valid_solve_orders, solve_order))

    if linear_op_solver is None:

        def default_convergence_test(x_new, x_old):
            min_type = converge.tree_smallest_float_dtype(x_new)
            rtol, atol = converge.adjust_tol_for_dtype(1e-10, 1e-10, min_type)
            return converge.max_diff_test(x_new, x_old, rtol, atol)

        def default_solver(linear_op, bvec, init_x=None):
            if init_x is None:
                init_x = bvec

            def _step_default_solver(i, x):
                del i
                return tree_util.tree_multimap(lax.add, linear_op(x), bvec)

            return loop.fixed_point_iteration(
                init_x=init_x,
                func=_step_default_solver,
                convergence_test=default_convergence_test,
                max_iter=default_max_iter,
            )

        linear_op_solver = default_solver

    step_size_f = optimizers.make_schedule(step_size_f)
    step_size_g = optimizers.make_schedule(step_size_g)

    def init(inputs):
        """Build a CGAState from ``(x, y)`` with zero-initialized deltas."""
        delta_x, delta_y = tree_util.tree_map(np.zeros_like, inputs)
        return CGAState(
            x=inputs[0],
            y=inputs[1],
            delta_x=delta_x,
            delta_y=delta_y,
        )

    def update(i, grads, inputs, *args, **kwargs):
        """Apply one CGA step at iteration ``i`` given ``grads = (grad_xf, grad_yg)``.

        ``inputs`` may be a bare ``(x, y)`` pair (no warm start for the
        solver) or a full 4-field state carrying the previous deltas.
        """
        if len(inputs) < 4:
            x, y = inputs
            delta_x = None
            delta_y = None
        else:
            x, y, delta_x, delta_y = inputs

        grad_xf, grad_yg = grads

        eta_f = step_size_f(i)
        eta_g = step_size_g(i)
        eta_fg = eta_g * eta_f

        jvp_xyf = make_mixed_jvp(partial(f, *args, **kwargs), x, y)
        jvp_yxg = make_mixed_jvp(partial(g, *args, **kwargs),
                                 x,
                                 y,
                                 opposite=True)

        def linear_op_x(x):
            return tree_util.tree_map(lambda v: eta_fg * v,
                                      jvp_xyf(jvp_yxg(x)))

        def linear_op_y(y):
            return tree_util.tree_map(lambda v: eta_fg * v,
                                      jvp_yxg(jvp_xyf(y)))

        def solve_delta_x(init_x):
            bx = tree_util.tree_multimap(
                lambda grad_xf, z: grad_xf + eta_g * z,
                grad_xf,
                jvp_xyf(grad_yg),
            )

            delta_x = linear_op_solver(linear_op=linear_op_x,
                                       bvec=bx,
                                       init_x=init_x).value

            return delta_x

        def solve_delta_y(init_y):
            by = tree_util.tree_multimap(
                lambda z, grad_yg: grad_yg + eta_f * z, jvp_yxg(grad_xf),
                grad_yg)

            delta_y = linear_op_solver(linear_op=linear_op_y,
                                       bvec=by,
                                       init_x=init_y).value

            return delta_y

        def solve_x_update_y(deltas):
            # Solve the x system, then derive delta_y from delta_x directly.
            delta_x, _ = deltas
            delta_x = solve_delta_x(delta_x)
            delta_y = tree_util.tree_multimap(lambda g_y, v: (g_y + eta_f * v),
                                              grad_yg, jvp_yxg(delta_x))
            return delta_x, delta_y

        def solve_y_update_x(deltas):
            # Solve the y system, then derive delta_x from delta_y directly.
            _, delta_y = deltas
            delta_y = solve_delta_y(delta_y)
            delta_x = tree_util.tree_multimap(lambda g_x, v: (g_x + eta_g * v),
                                              grad_xf, jvp_xyf(delta_y))
            return delta_x, delta_y

        def solve_both(deltas):
            delta_x, delta_y = deltas
            delta_x = solve_delta_x(delta_x)
            delta_y = solve_delta_y(delta_y)
            return delta_x, delta_y

        def solve_alternating(deltas):
            # Odd iterations: 'xy' order; even iterations: 'yx' order.
            return lax.cond(
                np.mod(i, 2).astype(bool), deltas, solve_x_update_y, deltas,
                solve_y_update_x)

        solver = {
            'simultaneous': solve_both,
            'alternating': solve_alternating,
            'xy': solve_x_update_y,
            'yx': solve_y_update_x
        }

        delta_x, delta_y = solver[solve_order]((delta_x, delta_y))

        x = tree_util.tree_multimap(lambda x, delta_x: x + eta_f * delta_x, x,
                                    delta_x)
        y = tree_util.tree_multimap(lambda y, delta_y: y + eta_g * delta_y, y,
                                    delta_y)
        return CGAState(x, y, delta_x, delta_y)

    def get_params(state):
        """Return the ``(x, y)`` pair from a state."""
        return state[:2]

    return init, update, get_params