Example #1
def main(unused_argv):

    train_size = FLAGS.train_size
    x_train, y_train, x_test, y_test = pickle.load(
        open("data_" + str(train_size) + ".p", "rb"))
    print("Got data")
    sys.stdout.flush()

    # Build the network
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        # stax.Erf(),
        stax.Relu(),
        stax.Dense(1, 1., 0.05))

    # Initialize the network a first time, to compute the NTK.
    # (numpy.random.random_integers is deprecated; use randint instead.)
    randnnn = numpy.random.randint(np.iinfo(np.int32).min,
                                   np.iinfo(np.int32).max)
    key = random.PRNGKey(randnnn)
    _, params = init_fn(key, (-1, 784))

    # Create an MSE predictor to solve the NTK equation in function space.
    # We assume the NTK is approximately the same for any draw of the
    # parameters (exact in the infinite-width limit).

    print("Making NTK")
    sys.stdout.flush()
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=1)
    g_dd = ntk(x_train, None, params)
    pickle.dump(g_dd, open("ntk_train_" + str(FLAGS.train_size) + ".p", "wb"))
    g_td = ntk(x_test, x_train, params)
    pickle.dump(g_td,
                open("ntk_train_test_" + str(FLAGS.train_size) + ".p", "wb"))
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)
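
The returned predictor maps initial function-space outputs to their values after a given training time; a minimal usage sketch (train_time here is a stand-in for the training duration, as in Example #15's FLAGS.train_time):

fx_train_0 = apply_fn(params, x_train)
fx_test_0 = apply_fn(params, x_test)
fx_train_t, fx_test_t = predictor(train_time, fx_train_0, fx_test_0)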
Example #2
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
    """Sample uniform random bits of given width and shape using PRNG key."""
    if not _is_threefry_prng_key(key):
        raise TypeError("_random_bits got invalid prng key.")
    if bit_width not in (8, 16, 32, 64):
        raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
    shape = core.as_named_shape(shape)
    for name, size in shape.named_items:
        real_size = lax.psum(1, name)
        if real_size != size:
            raise ValueError(
                f"The shape of axis {name} was specified as {size}, "
                f"but it really is {real_size}")
        axis_index = lax.axis_index(name)
        key = threefry_fold_in(key, axis_index)
    size = prod(shape.positional)
    # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
    # polymorphism
    max_count, r = divmod(bit_width * size, 32)
    if r > 0:
        max_count += 1

    if core.is_constant_dim(max_count):
        nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
    else:
        nblocks, rem = 0, max_count

    if not nblocks:
        bits = threefry_2x32(key, lax.iota(np.uint32, rem))
    else:
        keys = threefry_split(key, nblocks + 1)
        subkeys, last_key = keys[:-1], keys[-1]
        blocks = vmap(threefry_2x32,
                      in_axes=(0, None))(subkeys,
                                         lax.iota(np.uint32,
                                                  jnp.iinfo(np.uint32).max))
        last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
        bits = lax.concatenate([blocks.ravel(), last], 0)

    dtype = UINT_DTYPES[bit_width]
    if bit_width == 64:
        bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
        bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
    elif bit_width in [8, 16]:
        # this is essentially bits.view(dtype)[:size]
        bits = lax.bitwise_and(
            np.uint32(np.iinfo(dtype).max),
            lax.shift_right_logical(
                lax.broadcast(bits, (1, )),
                lax.mul(
                    np.uint32(bit_width),
                    lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
        bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width), ),
                           (1, 0))
        bits = lax.convert_element_type(bits, dtype)[:size]
    return lax.reshape(bits, shape)
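
For bit widths below 32, the branch above unpacks each 32-bit word into 32 // bit_width little-endian sub-fields, which is why the inline comment calls it essentially bits.view(dtype)[:size]. A plain-NumPy sketch of the same unpacking, with a made-up example word:

import numpy as np

word = np.uint32(0xDDCCBBAA)
fields = [(word >> np.uint32(8 * i)) & np.uint32(0xFF) for i in range(4)]
# fields == [0xAA, 0xBB, 0xCC, 0xDD], matching the order of
# np.array([word]).view(np.uint8) on a little-endian host.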
Example #3
File: v.py Project: sleepy-owl/coax
    def example_data(cls,
                     env,
                     observation_preprocessor=None,
                     batch_size=1,
                     random_seed=None):

        if not isinstance(env.observation_space, Space):
            raise TypeError(
                "env.observation_space must be derived from gym.Space, "
                f"got: {type(env.observation_space)}")

        if observation_preprocessor is None:
            observation_preprocessor = default_preprocessor(
                env.observation_space)

        rnd = onp.random.RandomState(random_seed)
        rngs = hk.PRNGSequence(rnd.randint(jnp.iinfo('int32').max))

        # input: state observations
        S = [
            safe_sample(env.observation_space, rnd) for _ in range(batch_size)
        ]
        S = [observation_preprocessor(next(rngs), s) for s in S]
        S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)

        return ExampleData(
            inputs=Inputs(args=ArgsType2(S=S, is_training=True),
                          static_argnums=(1, )),
            output=jnp.asarray(rnd.randn(batch_size)),
        )
Example #4
    def example_data(cls,
                     env,
                     observation_preprocessor,
                     action_preprocessor,
                     proba_dist,
                     batch_size=1,
                     random_seed=None):

        if not isinstance(env.observation_space, Space):
            raise TypeError(
                "env.observation_space must be derived from gym.Space, "
                f"got: {type(env.observation_space)}")
        if not isinstance(env.action_space, Space):
            raise TypeError(
                f"env.action_space must be derived from gym.Space, got: {type(env.action_space)}"
            )

        rnd = onp.random.RandomState(random_seed)
        rngs = hk.PRNGSequence(rnd.randint(jnp.iinfo('int32').max))

        # these must be provided
        assert observation_preprocessor is not None
        assert action_preprocessor is not None
        assert proba_dist is not None

        # input: state observations
        S = [
            safe_sample(env.observation_space, rnd) for _ in range(batch_size)
        ]
        S = [observation_preprocessor(next(rngs), s) for s in S]
        S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)

        # input: actions
        A = [safe_sample(env.action_space, rnd) for _ in range(batch_size)]
        A = [action_preprocessor(next(rngs), a) for a in A]
        A = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *A)

        # output: type1
        dist_params_type1 = jax.tree_map(
            lambda x: jnp.asarray(rnd.randn(batch_size, *x.shape[1:])),
            proba_dist.default_priors)
        data_type1 = ExampleData(inputs=Inputs(args=ArgsType1(
            S=S, A=A, is_training=True),
                                               static_argnums=(2, )),
                                 output=dist_params_type1)

        if not isinstance(env.action_space, Discrete):
            return ModelTypes(type1=data_type1, type2=None)

        # output: type2 (if actions are discrete)
        dist_params_type2 = jax.tree_map(
            lambda x: jnp.asarray(
                rnd.randn(batch_size, env.action_space.n, *x.shape[1:])),
            proba_dist.default_priors)
        data_type2 = ExampleData(inputs=Inputs(args=ArgsType2(
            S=S, is_training=True),
                                               static_argnums=(1, )),
                                 output=dist_params_type2)

        return ModelTypes(type1=data_type1, type2=data_type2)
Example #5
def dtype_min_value(dtype):
    if dtype.kind == 'f':
        return -jnp.inf
    elif dtype.kind == 'i':
        return jnp.iinfo(dtype).min
    elif dtype.kind == 'b':
        return False
    else:
        raise ValueError(f'Invalid data type {dtype.kind!r}.')
Example #6
def dtype_max_value(dtype):
    if dtype.kind == 'f':
        return jnp.inf
    elif dtype.kind == 'i':
        return jnp.iinfo(dtype).max
    elif dtype.kind == 'b':
        return True
    else:
        raise ValueError(f'Invalid data type {dtype.kind!r}.')
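
Taken together, Examples #5 and #6 return the extreme values of a dtype, e.g. for use as padding or as identity elements in min/max reductions. A small usage sketch (assuming both helpers are in scope; note that unsigned-integer dtypes, kind 'u', would hit the ValueError branch):

import numpy as np

assert dtype_min_value(np.dtype('int32')) == np.iinfo(np.int32).min
assert dtype_max_value(np.dtype('float32')) == np.inf
assert dtype_min_value(np.dtype('bool')) is False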
Example #7
def check_preprocessors(space,
                        *preprocessors,
                        num_samples=20,
                        random_seed=None):
    r"""

    Check whether two or more preprocessors are the same.

    Parameters
    ----------
    space : gym.Space

        The domain of the preprocessors.

    \*preprocessors

        Preprocessor functions, which are functions with input signature: :code:`func(rng: PRNGKey,
        x: Element[space]) -> Any`.

    num_samples : positive int

        The number of samples on which to run the checks.

    random_seed : int, optional

        Seed for the internal pseudo-random number generator.

    Returns
    -------
    match : bool

        Whether the preprocessors match.

    """
    if len(preprocessors) < 2:
        raise ValueError(
            "need at least two preprocessors in order to run test")

    def test_leaves(a, b):
        assert type(a) is type(b)
        return onp.testing.assert_allclose(onp.asanyarray(a),
                                           onp.asanyarray(b))

    rngs = hk.PRNGSequence(
        onp.random.RandomState(random_seed).randint(jnp.iinfo('int32').max))
    p0, *ps = preprocessors

    with jax.disable_jit():
        for _ in range(num_samples):
            x = space.sample()
            y0 = p0(next(rngs), x)
            for p in ps:
                y = p(next(rngs), x)
                if jax.tree_structure(y) != jax.tree_structure(y0):
                    return False
                try:
                    jax.tree_multimap(test_leaves, y, y0)
                except AssertionError:
                    return False
    return True
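
A minimal usage sketch with two toy preprocessors of my own (assuming gym is available and check_preprocessors is in scope); both ignore the rng argument and should therefore match:

from gym.spaces import Box
import jax.numpy as jnp

space = Box(low=-1.0, high=1.0, shape=(4,))
p1 = lambda rng, x: jnp.asarray(x, dtype='float32')
p2 = lambda rng, x: jnp.array(x, dtype='float32')
assert check_preprocessors(space, p1, p2, num_samples=5)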
Example #8
File: utils.py Project: deepmind/acme
def sample_uint32(random_key: jax_types.PRNGKey) -> int:
    """Returns an integer uniformly distributed in 0..2^32-1."""
    iinfo = jnp.iinfo(jnp.int32)
    # randint only accepts int32 values as min and max.
    jax_random = jax.random.randint(random_key,
                                    shape=(),
                                    minval=iinfo.min,
                                    maxval=iinfo.max,
                                    dtype=jnp.int32)
    return np.uint32(jax_random).item()
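
A usage sketch: the helper bridges JAX's keyed PRNG and APIs that expect a plain integer seed (the default_rng call below is my illustration, not part of the source):

import jax
import numpy as np

seed = sample_uint32(jax.random.PRNGKey(42))
rng = np.random.default_rng(seed)  # seed a NumPy generator from a JAX key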
Example #9
  def update_fn(updates, state, params=None):
    step_size = step_size_fn(state.count) * decay
    updates = jax.tree_multimap(lambda u, p: u - step_size * p, updates, params)

    # Saturating ("safe") int32 increment: stay at max_int32 instead of wrapping.
    max_int32_value = jnp.iinfo(jnp.int32).max
    new_count = jnp.where(state.count < max_int32_value,
                          state.count + 1,
                          max_int32_value)
    new_state = DecoupledWeightDecayState(count=new_count, step_size=step_size)

    return updates, new_state
Example #10
File: prng.py Project: Jakob-Unfried/jax
def _make_rotate_left(dtype):
  if not jnp.issubdtype(dtype, np.integer):
    raise TypeError("_rotate_left only accepts integer dtypes.")
  nbits = np.array(jnp.iinfo(dtype).bits, dtype)

  def _rotate_left(x, d):
    if lax.dtype(d) != dtype:
      d = lax.convert_element_type(d, dtype)
    if lax.dtype(x) != dtype:
      x = lax.convert_element_type(x, dtype)
    return lax.shift_left(x, d) | lax.shift_right_logical(x, nbits - d)
  return _rotate_left
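
A quick sanity check of the returned closure, with my own example values: rotating the uint32 pattern 0x80000001 left by one bit wraps the top bit around to the bottom.

import numpy as np

rotl = _make_rotate_left(np.uint32)
assert int(rotl(np.uint32(0x80000001), np.uint32(1))) == 0x00000003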
Example #11
def safe_int32_increment(count: chex.Numeric) -> chex.Numeric:
    """Increments int32 counter by one.
    Normally `max_int + 1` would overflow to `min_int`. This functions ensures
    that when `max_int` is reached the counter stays at `max_int`.
    Args:
      count: a counter to be incremented.
    Returns:
      A counter incremented by 1, or max_int if the maximum precision is reached.
    """
    chex.assert_type(count, jnp.int32)
    max_int32_value = jnp.iinfo(jnp.int32).max
    one = jnp.array(1, dtype=jnp.int32)
    return jnp.where(count < max_int32_value, count + one, max_int32_value)
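
A usage sketch showing the saturation behaviour (assuming the function above is in scope):

import jax.numpy as jnp

top = jnp.asarray(jnp.iinfo(jnp.int32).max, dtype=jnp.int32)
assert safe_int32_increment(jnp.asarray(41, dtype=jnp.int32)) == 42
assert safe_int32_increment(top) == top  # saturates instead of wrapping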
Example #12
    def example_data(
            cls, env, observation_preprocessor=None, action_preprocessor=None,
            batch_size=1, random_seed=None):

        if not isinstance(env.observation_space, Space):
            raise TypeError(
                "env.observation_space must be derived from gym.Space, "
                f"got: {type(env.observation_space)}")

        if not isinstance(env.action_space, Space):
            raise TypeError(
                f"env.action_space must be derived from gym.Space, got: {type(env.action_space)}")

        if observation_preprocessor is None:
            observation_preprocessor = ProbaDist(env.observation_space).preprocess_variate

        if action_preprocessor is None:
            action_preprocessor = default_preprocessor(env.action_space)

        rnd = onp.random.RandomState(random_seed)
        rngs = hk.PRNGSequence(rnd.randint(jnp.iinfo('int32').max))

        # input: state observations
        S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
        S = [observation_preprocessor(next(rngs), s) for s in S]
        S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)

        # input: actions
        A = [safe_sample(env.action_space, rnd) for _ in range(batch_size)]
        A = [action_preprocessor(next(rngs), a) for a in A]
        A = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *A)

        # output: type1
        S_next_type1 = jax.tree_map(lambda x: jnp.asarray(rnd.randn(batch_size, *x.shape[1:])), S)
        q1_data = ExampleData(
            inputs=Inputs(args=ArgsType1(S=S, A=A, is_training=True), static_argnums=(2,)),
            output=S_next_type1)

        if not isinstance(env.action_space, Discrete):
            return ModelTypes(type1=q1_data, type2=None)

        # output: type2 (if actions are discrete)
        S_next_type2 = jax.tree_map(
            lambda x: jnp.asarray(rnd.randn(batch_size, env.action_space.n, *x.shape[1:])), S)
        q2_data = ExampleData(
            inputs=Inputs(args=ArgsType2(S=S, is_training=True), static_argnums=(1,)),
            output=S_next_type2)

        return ModelTypes(type1=q1_data, type2=q2_data)
Example #13
def observe(state, seed=None):
    """Observes the classical state of a quantum system.

    Collapses the quantum state into a classical state, sampling
    from the distribution of classical states given by the amplitudes
    in the quantum system described by state.

    :param state: Vector representation of quantum state to observe.
    :param seed: Optional seed with which to sample randomly.
    :return: Integer in [0, state.shape[0])
    """
    seed = seed if seed is not None else random.randint(
        0,
        np.iinfo(np.int32).max)
    key = jax.random.PRNGKey(seed)
    p = np.real(np.conj(state) * state)
    return jax.random.categorical(key, np.log(p))
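
A usage sketch with an example state of my own; observing an equal superposition of the two basis states returns 0 or 1 with probability 1/2 each (this assumes np above is jax.numpy, as the function body suggests):

import jax.numpy as np

state = np.array([1.0 + 0.0j, 0.0 + 1.0j]) / np.sqrt(2.0)
outcome = observe(state, seed=0)  # an integer in {0, 1}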
Example #14
File: utils_test.py Project: BwRy/jraph
    def test_segment_max_negatives(self, indices_are_sorted, unique_indices):
        neg_inf = jnp.iinfo(jnp.int32).min
        if unique_indices:
            data = -1 - jnp.arange(6)  # [-1, -2, -3, -4, -5, -6]
            if indices_are_sorted:
                segment_ids = jnp.array([0, 1, 2, 3, 4, 5])
                expected_out = jnp.array([-1, -2, -3, -4, -5, -6])
                num_segments = 6
            else:
                segment_ids = jnp.array([1, 0, 2, 4, 3, -5])  # -5 is out of range: data[5] is dropped
                expected_out = jnp.array([-2, -1, -3, -5, -4])
                num_segments = 5
        else:
            data = -1 - jnp.arange(9)  # [-1, -2, -3, -4, -5, -6, -7, -8, -9]
            if indices_are_sorted:
                segment_ids = jnp.array([0, 0, 0, 1, 1, 1, 2, 3, 4])
                expected_out = jnp.array([-1, -4, -7, -8, -9, neg_inf])
            else:
                segment_ids = jnp.array([0, 1, 2, 0, 4, 0, 1, 1, -6])  # -6 is out of range: data[8] is dropped
                expected_out = jnp.array([-1, -2, -3, neg_inf, -5, neg_inf])
            num_segments = 6

        with self.subTest('nojit'):
            result = utils.segment_max(data, segment_ids, num_segments,
                                       indices_are_sorted, unique_indices)
            self.assertAllClose(result, expected_out, check_dtypes=True)
            result = utils.segment_max(data,
                                       segment_ids,
                                       indices_are_sorted=indices_are_sorted,
                                       unique_indices=unique_indices)
            num_unique_segments = jnp.maximum(
                jnp.max(segment_ids) + 1, jnp.max(-segment_ids))
            self.assertAllClose(result,
                                expected_out[:num_unique_segments],
                                check_dtypes=True)
        with self.subTest('jit'):
            result = jax.jit(utils.segment_max,
                             static_argnums=(2, 3, 4))(data, segment_ids,
                                                       num_segments,
                                                       indices_are_sorted,
                                                       unique_indices)
            self.assertAllClose(result, expected_out, check_dtypes=True)
Example #15
def main(unused_argv):
    using_SGD = FLAGS.using_SGD
    train_size = FLAGS.train_size
    x_train, y_train, x_test, y_test = pickle.load(
        open("data_" + str(train_size) + ".p", "rb"))
    print("Got data")
    sys.stdout.flush()

    # Build the network
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        # stax.Erf(),
        stax.Relu(),
        stax.Dense(1, 1., 0.05))

    # NOTE: only MSE loss and 0/1 labels are implemented for now.

    # Initialize the network a first time, to compute the NTK.
    randnnn = numpy.random.randint(np.iinfo(np.int32).min,
                                   np.iinfo(np.int32).max)
    key = random.PRNGKey(randnnn)
    _, params = init_fn(key, (-1, 784))

    # Create an MSE predictor to solve the NTK equation in function space.
    # We assume the NTK is approximately the same for any draw of the
    # parameters (exact in the infinite-width limit).
    sys.stdout.flush()
    g_dd = pickle.load(open("ntk_train_" + str(FLAGS.train_size) + ".p", "rb"))
    g_td = pickle.load(
        open("ntk_train_test_" + str(FLAGS.train_size) + ".p", "rb"))
    print("Got NTK")
    if not using_SGD:
        predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    batch_size = FLAGS.batch_size

    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    print(rank)
    for i in range(FLAGS.num_samples):
        if i % (ceil(FLAGS.num_samples / 100)) == 0:
            print(i)
            sys.stdout.flush()
        # Reinitialize the network.
        randnnn = numpy.random.randint(np.iinfo(np.int32).min,
                                       np.iinfo(np.int32).max)
        key = random.PRNGKey(randnnn)
        _, params = init_fn(key, (-1, 784))

        # Get initial values of the network in function space.
        fx_train = apply_fn(params, x_train)
        fx_test = apply_fn(params, x_test)

        if using_SGD:
            error = 1
            # Use the maximal stable learning rate for gradient descent on this NTK.
            lr = nt.predict.max_learning_rate(g_dd)
            print(lr)
            ntk_train = g_dd.squeeze()
            ntk_train_test = g_td.squeeze()
            if batch_size == train_size:
                indices = np.array(list(range(train_size)))
            while error >= 0.5:
                if batch_size != train_size:
                    indices = numpy.random.choice(range(train_size),
                                                  size=batch_size,
                                                  replace=False)
                fx_test = fx_test - lr * np.matmul(
                    ntk_train_test[:, indices],
                    (fx_train[indices] - y_train[indices])) / (2 * batch_size)
                fx_train = fx_train - lr * np.matmul(
                    ntk_train[:, indices],
                    (fx_train[indices] - y_train[indices])) / (2 * batch_size)
                # fx_train = jax.ops.index_add(fx_train, indices, -lr*np.matmul(ntk_train[:,indices],(fx_train[indices]-y_train[indices]))/(2*batch_size))
                # print(fx_train[0:10])
                error = np.dot(
                    (fx_train - y_train).squeeze(),
                    (fx_train - y_train).squeeze()) / (2 * train_size)
                #print(error)
        else:
            # Get predictions from analytic computation.
            fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

        # Binarize the test predictions and encode them as a bit string.
        OUTPUT = (fx_test > 0.5).astype(int)
        fun = ''.join([str(int(i)) for i in OUTPUT])
        TRUE_OUTPUT = (y_test > 0.5).astype(int)
        print("Generalization accuracy",
              np.sum(OUTPUT == TRUE_OUTPUT) / FLAGS.test_size)

        loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
        #util.print_summary('train', y_train, apply_fn(params, x_train), fx_train, loss)
        #util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)

        # Binarize the training predictions and check that training has converged.
        OUTPUT = (fx_train > 0.5).astype(int)
        TRUE_OUTPUT = (y_train > 0.5).astype(int)
        print("Training accuracy",
              np.sum(OUTPUT == TRUE_OUTPUT) / FLAGS.train_size)
        assert np.all(OUTPUT == TRUE_OUTPUT)

        # Append this sample's bit string to the per-rank results file
        # (the original opened the file inside the loop but closed it outside).
        with open('results/data_{}_large.txt'.format(rank), 'a') as f:
            f.write(fun + '\n')
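
The SGD branch above discretizes gradient-descent dynamics in function space: with minibatch $B$, kernel blocks $\Theta$, and the source's $\tfrac{1}{2n}$ MSE convention, each of the two matmul lines implements

$$f_{t+1}(X) = f_t(X) - \frac{\eta}{2|B|}\,\Theta(X, X_B)\,\bigl(f_t(X_B) - y_B\bigr),$$

with $X$ ranging over the training or the test inputs; the analytic predictor in the non-SGD branch solves the corresponding continuous-time dynamics in closed form.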
Example #16
import inspect
from functools import partial

import jax.numpy as jnp
from jax import jit

from onnx_jax.handlers.backend_handler import BackendHandler
from onnx_jax.handlers.handler import onnx_op
from onnx_jax.pb_wrapper import OnnxNode

int32_max = jnp.iinfo(jnp.int32).max


@onnx_op("Slice")
class Slice(BackendHandler):
    @classmethod
    def _common(cls, node: OnnxNode, **kwargs):
        cls._rewrite(node)
        cls._prepare(node)

        def _slice(x, starts, ends, axes=None, steps=None):
            if axes is not None:
                axes = tuple(axes)
            if steps is not None:
                steps = tuple(steps)
            # Clamp oversized end indices; ONNX exporters commonly use huge
            # sentinels (e.g. INT64_MAX) to mean "slice to the end".
            ends = [e if e < int32_max else int32_max for e in ends]
            return onnx_slice(x, tuple(starts), tuple(ends), axes, steps)

        return _slice

    @classmethod
Example #17
def _safe_int32_increment(count):
  chex.assert_type(count, jnp.int32)
  max_int32_value = jnp.iinfo(jnp.int32).max
  one = jnp.array(1, dtype=jnp.int32)
  return jnp.where(count < max_int32_value, count + one, max_int32_value)
Example #18
def main(unused_argv):
    loss = FLAGS.loss
    train_size = FLAGS.train_size
    x_train, y_train, x_test, y_test = pickle.load(
        open("data_" + str(train_size) + ".p", "rb"))
    print("Got data")
    sys.stdout.flush()

    # Build the network
    init_fn, apply_fn, _ = stax.serial(stax.Dense(512, 1.,
                                                  0.0357), stax.Relu(),
                                       stax.Dense(512, 1., 0.0357),
                                       stax.Relu(), stax.Dense(1, 1., 0.0357))

    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    #opt_init, opt_apply, get_params = optimizers.adam(0.001)

    # Create an mse loss function and a gradient function.
    if loss == "mse":
        loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
        decision_threshold = 0.5
    elif loss == "ce":
        #loss = lambda fx, yhat: np.mean( (yhat)*np.log(1+np.exp(-fx)) + (1-yhat)*(np.log(1+np.exp(-fx))+fx) )
        loss = lambda fx, y: np.sum(
            np.max(fx, 0) - fx * y + np.log(1 + np.exp(-np.abs(fx))))
        decision_threshold = 0.0
    else:
        raise NotImplementedError()
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    batch_size = FLAGS.batch_size

    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    print(rank)
    for i in range(FLAGS.num_samples):
        # Reinitialize the network.
        randnnn = numpy.random.randint(np.iinfo(np.int32).min,
                                       np.iinfo(np.int32).max)
        key = random.PRNGKey(randnnn)
        _, params = init_fn(key, (-1, 784))
        state = opt_init(params)
        # Get initial values of the network in function space.
        fx_train = apply_fn(params, x_train)
        # fx_test = apply_fn(params, x_test)

        OUTPUT = (fx_train > decision_threshold).astype(int)
        TRUE_OUTPUT = (y_train > decision_threshold).astype(int)
        train_acc = np.sum(OUTPUT == TRUE_OUTPUT) / FLAGS.train_size
        while train_acc < 1.0:
            if batch_size != train_size:
                indices = numpy.random.choice(range(train_size),
                                              size=batch_size,
                                              replace=False)
            else:
                indices = np.array(list(range(train_size)))
            x_batch = x_train[indices]
            y_batch = y_train[indices]
            state = opt_apply(i, grad_loss(params, x_batch, y_batch), state)  # note: i is the outer sample index, not a step counter
            params = get_params(state)
            fx_train_batch = apply_fn(params, x_batch)
            OUTPUT_batch = (fx_train_batch > decision_threshold).astype(int)
            TRUE_OUTPUT_batch = (y_batch > decision_threshold).astype(int)
            train_acc_batch = np.sum(
                OUTPUT_batch == TRUE_OUTPUT_batch) / FLAGS.batch_size
            #print(fx_train_batch)
            print(train_acc_batch)
            if train_acc_batch == 1.0:
                fx_train = apply_fn(params, x_train)
                OUTPUT = (fx_train > decision_threshold).astype(int)
                TRUE_OUTPUT = (y_train > decision_threshold).astype(int)
                train_acc = np.sum(OUTPUT == TRUE_OUTPUT) / FLAGS.train_size
                print("train_acc", train_acc)

        fx_train = apply_fn(params, x_train)
        fx_test = apply_fn(params, x_test)

        OUTPUT = (fx_train > decision_threshold).astype(int)
        TRUE_OUTPUT = (y_train > decision_threshold).astype(int)
        train_acc = np.sum(OUTPUT == TRUE_OUTPUT) / FLAGS.train_size
        print("Training accuracy", train_acc)
        assert train_acc == 1.0

        OUTPUT = (fx_test > decision_threshold).astype(int)
        fun = ''.join([str(int(i)) for i in OUTPUT])
        TRUE_OUTPUT = (y_test > decision_threshold).astype(int)
        test_acc = np.sum(OUTPUT == TRUE_OUTPUT) / FLAGS.test_size
        print("Generalization accuracy", test_acc)

        with open('data_{}_large.txt'.format(rank), 'a') as f:
            f.write(fun + '\n')
Example #19
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
      datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    import numpy

    # Binarize the one-hot labels: even digit -> 0, odd digit -> 1.
    y_train = np.expand_dims(numpy.argmax(y_train, 1) % 2, 1)
    y_test = np.expand_dims(numpy.argmax(y_test, 1) % 2, 1)
    # Build the network: a single wide Erf hidden layer.
    init_fn, apply_fn, _ = stax.serial(stax.Dense(2048, 1., 0.05), stax.Erf(),
                                       stax.Dense(1, 1., 0.05))

    # Draw a random seed for the PRNG key (random_integers is deprecated).
    randnnn = numpy.random.randint(np.iinfo(np.int32).min,
                                   np.iinfo(np.int32).max)
    key = random.PRNGKey(randnnn)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                       loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
Example #20
File: ops.py Project: ordabayevy/funsor
def _safesub(x, y):
    try:
        finfo = np.finfo(y.dtype)
    except ValueError:
        finfo = np.iinfo(y.dtype)
    return x + np.clip(-y, a_min=None, a_max=finfo.max)
Example #21
File: ops.py Project: ordabayevy/funsor
def _safediv(x, y):
    try:
        finfo = np.finfo(y.dtype)
    except ValueError:
        finfo = np.iinfo(y.dtype)
    return x * np.clip(np.reciprocal(y), a_min=None, a_max=finfo.max)
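
A usage sketch for _safediv with my own example values: clipping the reciprocal at finfo.max keeps the result finite where plain division would produce inf (_safesub clamps -y the same way):

import numpy as np

x = np.array([1.0, 2.0])
y = np.array([0.0, 4.0])
with np.errstate(divide='ignore'):
    print(_safediv(x, y))  # [1.7976931348623157e+308, 0.5] rather than [inf, 0.5]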