Example #1
 def test_tree_unflatten(self):
     tree = [(1, 2), {"roy": (3, [4, 5])}]
     flat, treedef = tree_flatten(tree)
     assert flat == [1, 2, 3, 4, 5]
     tree2 = tree_unflatten(flat, treedef)
     nodes_equal = tree_multimap(operator.eq, tree, tree2)
     assert tree_reduce(operator.and_, nodes_equal)
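For reference, a minimal standalone sketch of the same round-trip check, written against the current jax.tree_util API (where tree_map accepts multiple trees and tree_unflatten takes the treedef first); this is an illustration, not part of the original test:

import operator
from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_reduce

tree = [(1, 2), {"roy": (3, [4, 5])}]
flat, treedef = tree_flatten(tree)      # leaves in left-to-right order: [1, 2, 3, 4, 5]
tree2 = tree_unflatten(treedef, flat)   # rebuild the original structure
nodes_equal = tree_map(operator.eq, tree, tree2)
assert tree_reduce(operator.and_, nodes_equal)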
Example #2
def _project_on_columns(A, v):
  """
  Returns A.T.conj() @ v.
  """
  v_proj = tree_map(
      lambda X, y: _einsum("...n,...->n", X.conj(), y), A, v,
  )
  return tree_reduce(operator.add, v_proj)
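A hypothetical usage of this projection (the block names and shapes below are illustrative, and jnp.einsum stands in for the project's _einsum wrapper):

import operator
import jax.numpy as jnp
from jax.tree_util import tree_map, tree_reduce

# A is a pytree of blocks sharing a trailing column axis n; v holds matching batch vectors.
A = {"a": jnp.ones((4, 3)), "b": jnp.ones((2, 3))}
v = {"a": jnp.arange(4.0), "b": jnp.arange(2.0)}
proj = tree_reduce(
    operator.add,
    tree_map(lambda X, y: jnp.einsum("...n,...->n", X.conj(), y), A, v))
print(proj.shape)  # (3,): per-block A.T.conj() @ v contributions summed across blocks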
Example #3
 def minimum_tolerances(
     self, augmented: ComparingState[State, Comparand]
 ) -> Tuple[RealNumeric, RealNumeric]:
     """
     Returns:
         The minimum value of atol that would lead to convergence now.
         The minimum value of rtol that would lead to convergence now.
     """
     comparand = self.extract_comparand(augmented.current_state)
     abs_last = tree_map(jnp.abs, augmented.last_state)
     delta = tree_map(
         jnp.abs, tree_map(jnp.subtract, comparand, augmented.last_state))
     delta_over_b = tree_map(divide_nonnegative, delta, abs_last)
     minimum_atol = tree_reduce(jnp.maximum, tree_map(jnp.amax, delta), 0.0)
     minimum_rtol = tree_reduce(jnp.maximum,
                                tree_map(jnp.amax, delta_over_b), 0.0)
     return minimum_atol, minimum_rtol
Example #4
def max_diff_test(x_new, x_old, rtol, atol):
    def check_values(x, y):
        delta = np.max(np.abs(x - y))
        abs_old = np.max(np.abs(x))
        return close_or_nan(delta, abs_old, rtol, atol)

    is_close = tree_util.tree_multimap(check_values, x_new, x_old)
    return tree_util.tree_reduce(operator.and_, is_close)
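close_or_nan is not shown in this snippet; one plausible stand-in (an assumption, not the source's definition) accepts a pair of leaves when the maximum difference is within atol + rtol * scale, or when it is NaN:

import numpy as np

def close_or_nan(delta, scale, rtol, atol):
    # Hypothetical helper: True if within tolerance or if the difference is NaN,
    # so NaN entries do not block the overall check.
    return np.isnan(delta) | (delta <= atol + rtol * scale)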
Example #5
 def converged(
         self, augmented: ComparingState[State,
                                         Comparand]) -> BooleanNumeric:
     return tree_reduce(
         jnp.logical_and,
         tree_map(partial(jnp.allclose, rtol=self.rtol, atol=self.atol),
                  self.extract_comparand(augmented.current_state),
                  augmented.last_state), True)
Example #6
    def sum_and_contract(j1, j2):
        def contract(x, y):
            param_count = int(np.prod(x.shape[2:]))
            x = np.reshape(x, x.shape[:2] + (param_count, ))
            y = np.reshape(y, y.shape[:2] + (param_count, ))
            return np.dot(x, np.transpose(y, (0, 2, 1)))

        return tree_reduce(operator.add, tree_multimap(contract, j1, j2))
Example #7
 def gradient_variance(grad_fn, hyperparameters):
     """
     Variance of Gradient w.r.t. hyperparameters
     :param grad_fn: Gradient computing function w.r.t. tunable hyperparameters
     :param hyperparameters: cv_coeff and log_temperature (dict)
     """
     gradients = grad_fn(**hyperparameters)
     var = tree_reduce(lambda x, y: x + y,
                       tree_map(lambda x: np.sum(x**2), gradients))
     return var
Example #8
  def sum_and_contract(j1, j2, output_ndim):
    _diagonal_axes = utils.canonicalize_axis(diagonal_axes, output_ndim)
    _trace_axes = utils.canonicalize_axis(trace_axes, output_ndim)

    def contract(x, y):
      param_axes = list(range(x.ndim))[output_ndim:]
      contract_axes = _trace_axes + param_axes
      return utils.dot_general(x, y, contract_axes, _diagonal_axes)

    return tree_reduce(operator.add, tree_multimap(contract, j1, j2))
Example #9
    def gradient_variance(grad_fn, hyperparameters):
        """
        Variance of the gradient w.r.t. surrogate parameters.

        :param grad_fn: Gradient computing function w.r.t. tunable hyperparameters
        :param hyperparameters: parameters of the RELAX surrogate (list)
        """
        gradients = grad_fn(
            surrogate_params=hyperparameters['surrogate_params'])
        var = tree_reduce(lambda x, y: x + y,
                          tree_map(lambda x: np.sum(x**2), gradients))
        return var
Example #10
  def sum_and_contract(fx, j1, j2):
    ndim = fx.ndim
    size = utils.size_at(fx, trace_axes)

    _diagonal_axes = utils.canonicalize_axis(diagonal_axes, ndim)
    _trace_axes = utils.canonicalize_axis(trace_axes, ndim)

    def contract(x, y):
      param_axes = list(range(x.ndim))[ndim:]
      contract_axes = _trace_axes + param_axes
      return utils.dot_general(x, y, contract_axes, _diagonal_axes) / size

    return tree_reduce(operator.add, tree_map(contract, j1, j2))
Example #11
 def converged(
         self, augmented: StochasticState[State,
                                          Comparand]) -> BooleanNumeric:
     data_weight = leaky_data_weight(augmented.iterations,
                                     self.convergence_detection_decay)
     mean_squared = tree_map(abs_square, augmented.mean_state)
     return tree_reduce(
         jnp.logical_and,
         tree_map(
             partial(jnp.allclose,
                     rtol=self.rtol * data_weight,
                     atol=self.atol * data_weight),
             augmented.second_moment_state, mean_squared), True)
Example #12
    def minimum_tolerances(
        self, augmented: StochasticState[State, Comparand]
    ) -> Tuple[RealNumeric, RealNumeric]:
        """
        Returns:
            The minimum value of atol that would lead to convergence now.
            The minimum value of rtol that would lead to convergence now.
        """
        data_weight = leaky_data_weight(augmented.iterations,
                                        self.convergence_detection_decay)
        mean_squared = tree_map(abs_square, augmented.mean_state)
        variance = tree_map(jnp.subtract, augmented.second_moment_state,
                            mean_squared)
        scaled_variance = tree_map(divide_nonnegative, variance, mean_squared)

        minimum_atol = divide_nonnegative(
            tree_reduce(jnp.maximum, tree_map(jnp.amax, variance), 0.0),
            data_weight)
        minimum_rtol = divide_nonnegative(
            tree_reduce(jnp.maximum, tree_map(jnp.amax, scaled_variance), 0.0),
            data_weight)
        assert not isinstance(minimum_atol, complex)
        assert not isinstance(minimum_rtol, complex)
        return minimum_atol, minimum_rtol
Example #13
def tree_allclose(actual: PyTree,
                  desired: PyTree,
                  rtol: Optional[float] = None,
                  atol: Optional[float] = None) -> bool:
    """
    Args:
        actual: The actual value.
        desired: The desired value.
        rtol: The relative tolerance of the comparison.
        atol: The absolute tolerance of the comparison.
    """
    def allclose(actual_array: Array, desired_array: Array) -> BooleanNumeric:
        dtype = jnp.result_type(actual_array, desired_array)
        tols = default_tols(dtype.type, rtol=rtol,
                            atol=atol)  # pyright: ignore
        return bool(jnp.allclose(actual_array, desired_array, **tols))

    return tree_reduce(jnp.logical_and, tree_map(allclose, actual, desired),
                       True)
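A simplified, self-contained variant of the same tree_map/tree_reduce pattern (dropping the project-specific per-dtype default_tols lookup in favor of fixed numpy-style tolerances):

import jax.numpy as jnp
from jax.tree_util import tree_map, tree_reduce

def tree_allclose_simple(actual, desired, rtol=1e-5, atol=1e-8):
    # Leafwise allclose, then a logical-and reduction with True as the initializer.
    leafwise = tree_map(lambda a, d: bool(jnp.allclose(a, d, rtol=rtol, atol=atol)),
                        actual, desired)
    return tree_reduce(jnp.logical_and, leafwise, True)

assert tree_allclose_simple({"w": jnp.ones(3)}, {"w": jnp.ones(3) + 1e-9})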
Example #14
def adam(vg_fun, loader, params0, epochs=10, eta=0.01, gamma=0.9, disp=None):
    # parameter info
    params = tree_map(np.array, params0)

    # track rms gradient
    grms = tree_map(np.zeros_like, params)

    # do training
    for ep in range(epochs):
        # epoch stats
        agg_loss, agg_batch = 0.0, 0

        # iterate over batches
        for b, batch in enumerate(loader):
            # compute gradients
            loss, grad = vg_fun(params, batch)

            lnan = np.isnan(loss)
            gnan = tree_reduce(and_, tree_map(lambda g: np.isnan(g).any(),
                                              grad))

            if lnan or gnan:
                print('Encountered nans!')
                return params, None

            grms = tree_map(lambda r, g: gamma * r + (1 - gamma) * g**2, grms,
                            grad)
            params = tree_map(lambda p, g, r: p + eta * g / np.sqrt(r + eps),
                              params, grad, grms)

            # compute statistics
            agg_loss += loss
            agg_batch += 1

        # display stats
        avg_loss = agg_loss / agg_batch

        # display output
        if disp is not None:
            disp(ep, avg_loss, params)

    return params
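A hypothetical driver for the excerpt above. It relies on module-level names from its original file, so the missing ones are filled in here as assumptions (operator.and_, the tree utilities, the np alias, and a small eps); the toy objective is a negated MSE because the update rule takes a gradient ascent step (p + eta * g):

import jax
import jax.numpy as np                           # the excerpt's np alias (assumed to be jax.numpy)
from jax.tree_util import tree_map, tree_reduce  # used inside adam()
from operator import and_                        # used by the NaN check inside adam()

eps = 1e-8                                       # smoothing term in the update (assumed value)

def objective(params, batch):
    # Toy fit of y = 2 * x; negated so that ascent decreases the squared error.
    x, y = batch
    return -np.mean((y - params['w'] * x) ** 2)

vg_fun = jax.jit(jax.value_and_grad(objective))
loader = [(np.arange(4.0), 2.0 * np.arange(4.0))] * 16
params = adam(vg_fun, loader, {'w': np.zeros(())}, epochs=5, eta=0.1)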
Example #15
 def hessian_vector_product(v: Weights) -> Weights:
     d = tree_reduce(jnp.add, tree_map(jnp.vdot, gradient, v), 0.0)
     return tree_map(lambda x: x * d, gradient)
Example #16
def _tree_dot(t1, t2):
    assert len(t1) == len(t2)

    def leaf_dot(x, y):
        # Per-leaf contribution to the dot product: sum of elementwise products.
        return (x * y).sum()

    # Equivalent to jnp.dot over the flattened trees.
    return tu.tree_reduce(lambda x, y: x + y, tu.tree_multimap(leaf_dot, t1, t2))
Example #17
def tree_smallest_float_dtype(x):
    return tree_util.tree_reduce(_min_float_dtype,
                                 tree_util.tree_map(lambda x: x.dtype, x))
Example #18
def main(unused_argv):
    from jax.api import grad, jit, vmap, pmap, device_put
    "The following is required to use TPU Driver as JAX's backend."

    if FLAGS.TPU:
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + os.environ[
            'TPU_ADDR'] + ':8470'
        TPU_ADDR = os.environ['TPU_ADDR']
    ndevices = xla_bridge.device_count()
    if not FLAGS.TPU:
        ndevices = 1

    pmap = partial(pmap, axis_name='i')
    """Setup some experiment parameters."""
    meas_step = FLAGS.meas_step
    training_epochs = int(FLAGS.epochs)

    tmult = 1.0
    if FLAGS.physical:
        tmult = FLAGS.lr
        if FLAGS.physicalL2:
            tmult = FLAGS.L2 * tmult
    if FLAGS.physical:
        training_epochs = 1 + int(FLAGS.epochs / tmult)

    print('Evolving for {:}e'.format(training_epochs))
    losst = FLAGS.losst
    learning_rate = FLAGS.lr
    batch_size_per_device = FLAGS.bs
    N = FLAGS.N
    K = FLAGS.K

    batch_size = batch_size_per_device * ndevices
    steps_per_epoch = 50000 // batch_size
    training_steps = training_epochs * steps_per_epoch

    "Filename from FLAGS"

    filename = 'wrnL2_' + losst + '_n' + str(N) + '_k' + str(K)
    if FLAGS.momentum:
        filename += '_mom'
    if FLAGS.L2_sch:
        filename += '_L2sch' + '_decay' + str(FLAGS.L2dec) + '_del' + str(
            FLAGS.delay)
    if FLAGS.seed != 1:
        filename += 'seed' + str(FLAGS.seed)
    filename += '_L2' + str(FLAGS.L2)
    if FLAGS.std_wrn_sch:
        filename += '_stddec'
        if FLAGS.physical:
            filename += 'phys'
    else:
        filename += '_ctlr'
    if not FLAGS.augment:
        filename += '_noaug'
    if not FLAGS.mix:
        filename += '_nomixup'
    filename += '_bs' + str(batch_size) + '_lr' + str(learning_rate)
    if FLAGS.jobdir is not None:
        filedir = os.path.join('wrnlogs', FLAGS.jobdir)
    else:
        filedir = 'wrnlogs'
    if not os.path.exists(filedir):
        os.makedirs(filedir)
    filedir = os.path.join(filedir, filename + '.csv')

    print('Saving log to ', filename)
    print('Found {} cores.'.format(ndevices))
    """Load CIFAR10 data and create a minimal pipeline."""

    train_images, train_labels, test_images, test_labels = utils.load_data(
        'cifar10')
    train_images = np.reshape(train_images, (-1, 32, 32 * 3))
    train = (train_images, train_labels)
    test = (test_images, test_labels)
    k = train_labels.shape[-1]
    train = utils.shard_data(train, ndevices)
    test = utils.shard_data(test, ndevices)
    """Create a Wide Resnet and replicate its parameters across the devices."""

    initparams, f, _ = utils.WideResnetnt(N, K, k)

    "Loss and optimizer definitions"

    l2_norm = lambda params: tree_map(lambda x: np.sum(x**2), params)
    l2_reg = lambda params: tree_reduce(lambda x, y: x + y, l2_norm(params))
    currL2 = FLAGS.L2
    L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))

    def xentr(params, images_and_labels):
        images, labels = images_and_labels
        return -np.mean(stax.logsoftmax(f(params, images)) * labels)

    def mse(params, data_tuple):
        """MSE loss."""
        x, y = data_tuple
        return 0.5 * np.mean((y - f(params, x))**2)

    if losst == 'xentr':
        print('Using xentr')
        lossm = xentr
    else:
        print('Using mse')
        lossm = mse

    loss = lambda params, data, L2: lossm(params, data) + L2 * l2_reg(params)

    def accuracy(params, images_and_labels):
        images, labels = images_and_labels
        return np.mean(
            np.array(np.argmax(f(params, images), axis=1) == np.argmax(labels,
                                                                       axis=1),
                     dtype=np.float32))

    "Define optimizer"

    if FLAGS.std_wrn_sch:
        lr = learning_rate
        first_epoch = int(60 / 200 * training_epochs)
        learning_rate_fn = optimizers.piecewise_constant(
            np.array([1, 2, 3]) * first_epoch * steps_per_epoch,
            np.array([lr, lr * 0.2, lr * 0.2**2, lr * 0.2**3]))
    else:
        learning_rate_fn = optimizers.make_schedule(learning_rate)

    if FLAGS.momentum:
        momentum = 0.9
    else:
        momentum = 0

    @pmap
    def update_step(step, state, batch_state, L2):
        batch, batch_state = batch_fn(batch_state)
        params = get_params(state)
        dparams = grad_loss(params, batch, L2)
        dparams = tree_map(lambda x: lax.psum(x, 'i') / ndevices, dparams)
        return step + 1, apply_fn(step, dparams, state), batch_state

    @pmap
    def evaluate(state, data, L2):
        params = get_params(state)
        lossmm = lossm(params, data)
        l2mm = l2_reg(params)
        return lossmm + L2 * l2mm, accuracy(params, data), lossmm, l2mm

    "Initialization and loading"

    _, params = initparams(random.PRNGKey(0), (-1, 32, 32, 3))
    replicate_array = lambda x: \
        np.broadcast_to(x, (ndevices,) + x.shape)
    replicated_params = tree_map(replicate_array, params)

    grad_loss = jit(grad(loss))
    init_fn, apply_fn, get_params = optimizers.momentum(
        learning_rate_fn, momentum)
    apply_fn = jit(apply_fn)
    key = random.PRNGKey(FLAGS.seed)

    batchinit_fn, batch_fn = utils.sharded_minibatcher(batch_size,
                                                       ndevices,
                                                       transform=FLAGS.augment,
                                                       k=k,
                                                       mix=FLAGS.mix)

    batch_state = pmap(batchinit_fn)(random.split(key, ndevices), train)
    state = pmap(init_fn)(replicated_params)

    if FLAGS.checkpointing:
        ## Loading of checkpoint if available/provided.
        single_state = init_fn(params)
        i0, load_state, load_params, filename0, batch_stateb = utils.load_weights(
            filename,
            single_state,
            params,
            full_file=FLAGS.load_w,
            ndevices=ndevices)
        if i0 is not None:
            filename = filename0
            if batch_stateb is not None:
                batch_state = batch_stateb
            if load_params is not None:
                state = pmap(init_fn)(load_params)
            else:
                state = load_state
        else:
            i0 = 0
    else:
        i0 = 0

    if FLAGS.steps_from_load:
        training_steps = i0 + training_steps

    batch_xs, _ = pmap(batch_fn)(batch_state)

    train_loss = []
    train_accuracy = []
    lrL = []
    test_loss = []
    test_accuracy = []
    test_L2, test_lm, train_lm, train_L2 = [], [], [], []
    L2_t = []
    idel0 = i0
    start = time.time()

    step = pmap(lambda x: x)(i0 * np.ones((ndevices, )))

    "Start training loop"
    if FLAGS.checkpointing:
        print('Evolving for {:}e and saving every {:}s'.format(
            training_epochs, FLAGS.checkpointing))

    print(
        'Epoch\tLearning Rate\tTrain bareLoss\t L2_norm \tTest Loss\tTrain Error\tTest Error\tTime / Epoch'
    )

    for i in range(i0, training_steps):
        if i % meas_step == 0:
            # Make Measurement
            l, a, lm, L2m = evaluate(state, test, L2p)
            test_loss += [np.mean(l)]
            test_accuracy += [np.mean(a)]
            test_lm += [np.mean(lm)]
            test_L2 += [np.mean(L2m)]
            train_batch, _ = pmap(batch_fn)(batch_state)
            l, a, lm, L2m = evaluate(state, train_batch, L2p)

            train_loss += [np.mean(l)]
            train_accuracy += [np.mean(a)]
            train_lm += [np.mean(lm)]
            train_L2 += [np.mean(L2m)]
            L2_t.append(currL2)
            lrL += [learning_rate_fn(i)]

            if FLAGS.L2_sch and i > FLAGS.delay / currL2 + idel0 and len(
                    train_lm) > 2 and ((minloss <= train_lm[-1]
                                        and minloss <= train_lm[-2]) or
                                       (maxacc >= train_accuracy[-1]
                                        and maxacc >= train_accuracy[-2])):
                # If AutoL2 is on and we are beyond the refractory period, decay if the loss or error have increased in the last two measurements.
                print('Decaying L2 to', currL2 / FLAGS.L2dec)
                currL2 = currL2 / FLAGS.L2dec
                L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))
                idel0 = i

            elif FLAGS.L2_sch and len(train_lm) >= 2:
                # Update the minimum values.
                try:
                    maxacc = max(train_accuracy[-2], maxacc)
                    minloss = min(train_lm[-2], minloss)
                except NameError:
                    maxacc, minloss = train_accuracy[-2], train_lm[-2]

            if i % (meas_step * 10) == 0 or i == i0:
                # Save measurements to csv
                epoch = batch_size * i / 50000
                dt = (time.time() - start) / (meas_step * 10) * steps_per_epoch
                print(('{}\t' + ('{: .4f}\t' * 7)).format(
                    epoch, learning_rate_fn(i), train_lm[-1], train_L2[-1],
                    test_loss[-1], train_accuracy[-1], test_accuracy[-1], dt))

                start = time.time()
                data = {
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'train_acc': train_accuracy,
                    'test_acc': test_accuracy
                }
                data['train_bareloss'] = train_lm
                data['train_L2'] = train_L2
                data['test_bareloss'] = test_lm
                data['test_L2'] = test_L2
                data['L2_t'] = L2_t
                df = pd.DataFrame(data)

                df['learning_rate'] = lrL
                df['width'] = K
                df['batch_size'] = batch_size
                df['step'] = i0 + onp.arange(0, len(train_loss)) * meas_step

                df.to_csv(filedir, index=False)

        if FLAGS.checkpointing:
            ### SAVE MODEL
            if i % FLAGS.checkpointing == 0 and i > i0:

                if not os.path.exists('weights/'):
                    os.makedirs('weights/')
                saveparams = tree_flatten(state[0])[0]
                if ndevices > 1:
                    saveparams = [el[0] for el in saveparams]
                saveparams = np.concatenate(
                    [el.reshape(-1) for el in saveparams])

                step0 = i
                print('Step', i)
                print('saving at', filename, step0, 'size:', saveparams.shape)

                utils.save_weights(filename, step0, saveparams, batch_state)

        ## UPDATE
        step, state, batch_state = update_step(step, state, batch_state, L2p)

    print('Training done')

    if FLAGS.TPU:
        with open('done/' + TPU_ADDR, 'w') as fp:
            fp.write(filedir)
            pass
Example #19
            lax.dynamic_slice(x, start, size)
            for x, start, size in zip(data, slice_start, slice_size)
        ]

        if transform is not None:
            key, subkey = random.split(key)
            batch = transform(subkey, batch)

        i = i + 1
        key, data, i = lax.cond(i >= num_batches, (key, data), shuffle,
                                (key, data, i), lambda x: x)

        return batch, (key, data, i, num_batches)

    return init_fn, batch_fn


# END: shard data pipeline.

# Loss Definition.
_l2_norm = lambda params: tree_map(lambda x: np.sum(x**2), params)
l2_regularization = lambda params: tree_reduce(operator.add, _l2_norm(params))

cross_entropy = lambda y, y_hat: -np.mean(np.sum(y * y_hat, axis=1))


# Learning rate schedule of Cosine.
def cosine_schedule(initial_lr, training_steps):
    return lambda t: initial_lr * 0.5 * (1 + np.cos(t / training_steps * np.pi))
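A quick check of the schedule (a sketch; assumes the surrounding module's np alias for jax.numpy):

lr_fn = cosine_schedule(0.1, 1000)
print(lr_fn(0), lr_fn(500), lr_fn(1000))  # ~0.1, ~0.05, ~0.0: half-cosine decay to zero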
Example #20
def pytree_dot(x, y) -> float:
    partial_dot = tree_util.tree_multimap(
        lambda arr1, arr2: np.sum(arr1 * arr2), x, y)
    return tree_util.tree_reduce(lax.add, partial_dot)
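A hypothetical call (assuming the module's np alias for jax.numpy and its tree_util/lax imports):

x = {"w": np.ones((2, 3)), "b": np.arange(3.0)}
y = {"w": np.full((2, 3), 2.0), "b": np.arange(3.0)}
print(pytree_dot(x, y))  # 12.0 + 5.0 = 17.0: sum over leaves of elementwise products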
Example #21
def _tree_concatenate(x):
    return tree_util.tree_reduce(lambda a, b: np.concatenate((a, b)), x)
Example #22
def inner_prod(xs, ys):
  def contract(x, y):
    return np.real(np.dot(np.conj(x).reshape(-1), y.reshape(-1)))
  return tree_reduce(np.add, tree_map(contract, xs, ys))
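For example (a sketch using the snippet's np alias for jax.numpy), inner_prod returns the real part of the inner product accumulated over all leaves, which also handles complex parameters:

xs = [np.array([1.0 + 1.0j, 2.0]), np.ones((2, 2))]
ys = [np.array([1.0 - 1.0j, 1.0]), np.ones((2, 2))]
print(inner_prod(xs, ys))  # 2.0 + 4.0 = 6.0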
Example #23
def pytree_relative_error(x, y):
    partial_error = tree_util.tree_multimap(
        lambda a, b: l2_norm(pytree_sub(a, b)) / (l2_norm(a) + 1e-5), x, y)
    return tree_util.tree_reduce(lax.add, partial_error)
Example #24
def pytree_shape_array_equal(x, y):
    is_eq = tree_util.tree_multimap(
        lambda arr1, arr2: (arr1.shape == arr2.shape), x, y)
    return tree_util.tree_reduce(operator.and_, is_eq)
Example #25
def pytree_array_equal(x, y):
    is_eq = tree_util.tree_multimap(
        lambda arr1, arr2: np.array_equal(arr1, arr2), x, y)
    return tree_util.tree_reduce(operator.and_, is_eq)
Example #26
def reduce_loss_tree(loss_tree: Mapping) -> jnp.ndarray:
    """Reduces a loss tree to a scalar (i.e. a jnp.ndarray with size 1)."""
    return tree_util.tree_reduce(lambda x, y: x + y, loss_tree)
Example #27
 def count_parameters(params):
     return tree_util.tree_reduce(
         operator.add, tree_util.tree_map(lambda x: np.prod(x.shape),
                                          params))
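A hypothetical call (assuming numpy or jax.numpy as np and jax.tree_util as tree_util, as in the snippet):

params = {"dense": {"w": np.zeros((128, 64)), "b": np.zeros(64)},
          "out": {"w": np.zeros((64, 10))}}
print(count_parameters(params))  # 128*64 + 64 + 64*10 = 8896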