def main(unused_argv):
    train_size = FLAGS.train_size
    x_train, y_train, x_test, y_test = pickle.load(
        open("data_" + str(train_size) + ".p", "rb"))
    print("Got data")
    sys.stdout.flush()

    # Build the network
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        # stax.Erf(),
        stax.Relu(),
        stax.Dense(1, 1., 0.05))

    # Initialize the network a first time, to compute the NTK.
    # (numpy.random.random_integers was removed from NumPy; randint is the
    # drop-in replacement, with an exclusive upper bound.)
    randnnn = numpy.random.randint(np.iinfo(np.int32).min,
                                   high=np.iinfo(np.int32).max,
                                   size=2)[0]
    key = random.PRNGKey(randnnn)
    _, params = init_fn(key, (-1, 784))

    # Create an MSE predictor to solve the NTK equation in function space.
    # We assume the NTK is approximately the same for any sample of
    # parameters (true in the limit of infinite width).
    print("Making NTK")
    sys.stdout.flush()
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=1)
    g_dd = ntk(x_train, None, params)
    pickle.dump(g_dd, open("ntk_train_" + str(FLAGS.train_size) + ".p", "wb"))
    g_td = ntk(x_test, x_train, params)
    pickle.dump(g_td,
                open("ntk_train_test_" + str(FLAGS.train_size) + ".p", "wb"))
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
  """Sample uniform random bits of given width and shape using PRNG key."""
  if not _is_threefry_prng_key(key):
    raise TypeError("_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")
  shape = core.as_named_shape(shape)
  for name, size in shape.named_items:
    real_size = lax.psum(1, name)
    if real_size != size:
      raise ValueError(f"The shape of axis {name} was specified as {size}, "
                       f"but it really is {real_size}")
    axis_index = lax.axis_index(name)
    key = threefry_fold_in(key, axis_index)
  size = prod(shape.positional)
  # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
  # polymorphism
  max_count, r = divmod(bit_width * size, 32)
  if r > 0:
    max_count += 1

  if core.is_constant_dim(max_count):
    nblocks, rem = divmod(max_count, jnp.iinfo(np.uint32).max)
  else:
    nblocks, rem = 0, max_count

  if not nblocks:
    bits = threefry_2x32(key, lax.iota(np.uint32, rem))
  else:
    keys = threefry_split(key, nblocks + 1)
    subkeys, last_key = keys[:-1], keys[-1]
    blocks = vmap(threefry_2x32, in_axes=(0, None))(
        subkeys, lax.iota(np.uint32, jnp.iinfo(np.uint32).max))
    last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
    bits = lax.concatenate([blocks.ravel(), last], 0)

  dtype = UINT_DTYPES[bit_width]
  if bit_width == 64:
    bits = [lax.convert_element_type(x, dtype) for x in jnp.split(bits, 2)]
    bits = lax.shift_left(bits[0], dtype(32)) | bits[1]
  elif bit_width in [8, 16]:
    # this is essentially bits.view(dtype)[:size]
    bits = lax.bitwise_and(
        np.uint32(np.iinfo(dtype).max),
        lax.shift_right_logical(
            lax.broadcast(bits, (1,)),
            lax.mul(
                np.uint32(bit_width),
                lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))))
    bits = lax.reshape(bits, (np.uint32(max_count * 32 // bit_width),), (1, 0))
    bits = lax.convert_element_type(bits, dtype)[:size]
  return lax.reshape(bits, shape)
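# Public-API sketch (illustrative; not from the original source): recent JAX
# releases expose this machinery as jax.random.bits, which draws raw uniform
# bits of a given unsigned integer dtype without going through the internal
# threefry_random_bits.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
bits16 = jax.random.bits(key, (4,), dtype=jnp.uint16)  # four uniform 16-bit words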
def example_data(cls, env, observation_preprocessor=None, batch_size=1,
                 random_seed=None):
    if not isinstance(env.observation_space, Space):
        raise TypeError(
            "env.observation_space must be derived from gym.Space, "
            f"got: {type(env.observation_space)}")

    if observation_preprocessor is None:
        observation_preprocessor = default_preprocessor(env.observation_space)

    rnd = onp.random.RandomState(random_seed)
    rngs = hk.PRNGSequence(rnd.randint(jnp.iinfo('int32').max))

    # input: state observations
    S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
    S = [observation_preprocessor(next(rngs), s) for s in S]
    S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)

    return ExampleData(
        inputs=Inputs(args=ArgsType2(S=S, is_training=True),
                      static_argnums=(1,)),
        output=jnp.asarray(rnd.randn(batch_size)),
    )
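# Usage sketch (illustrative; not from the original source, and the owning
# class `SomeModel` is hypothetical, assuming example_data is bound as a
# classmethod): build a batch of example observations, e.g. for tracing or
# initializing a Haiku-transformed function.
import gym

env = gym.make('CartPole-v0')
example = SomeModel.example_data(env, batch_size=2, random_seed=13)
# example.inputs.args.S has a leading batch axis of size 2.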
def example_data(cls, env, observation_preprocessor, action_preprocessor,
                 proba_dist, batch_size=1, random_seed=None):
    if not isinstance(env.observation_space, Space):
        raise TypeError(
            "env.observation_space must be derived from gym.Space, "
            f"got: {type(env.observation_space)}")
    if not isinstance(env.action_space, Space):
        raise TypeError(
            "env.action_space must be derived from gym.Space, "
            f"got: {type(env.action_space)}")

    rnd = onp.random.RandomState(random_seed)
    rngs = hk.PRNGSequence(rnd.randint(jnp.iinfo('int32').max))

    # these must be provided
    assert observation_preprocessor is not None
    assert action_preprocessor is not None
    assert proba_dist is not None

    # input: state observations
    S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
    S = [observation_preprocessor(next(rngs), s) for s in S]
    S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)

    # input: actions
    A = [safe_sample(env.action_space, rnd) for _ in range(batch_size)]
    A = [action_preprocessor(next(rngs), a) for a in A]
    A = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *A)

    # output: type1
    dist_params_type1 = jax.tree_map(
        lambda x: jnp.asarray(rnd.randn(batch_size, *x.shape[1:])),
        proba_dist.default_priors)
    data_type1 = ExampleData(
        inputs=Inputs(args=ArgsType1(S=S, A=A, is_training=True),
                      static_argnums=(2,)),
        output=dist_params_type1)

    if not isinstance(env.action_space, Discrete):
        return ModelTypes(type1=data_type1, type2=None)

    # output: type2 (if actions are discrete)
    dist_params_type2 = jax.tree_map(
        lambda x: jnp.asarray(
            rnd.randn(batch_size, env.action_space.n, *x.shape[1:])),
        proba_dist.default_priors)
    data_type2 = ExampleData(
        inputs=Inputs(args=ArgsType2(S=S, is_training=True),
                      static_argnums=(1,)),
        output=dist_params_type2)

    return ModelTypes(type1=data_type1, type2=data_type2)
def dtype_min_value(dtype):
    if dtype.kind == 'f':
        return -jnp.inf
    elif dtype.kind == 'i':
        return jnp.iinfo(dtype).min
    elif dtype.kind == 'b':
        return False
    else:
        raise ValueError(f'Invalid data type {dtype.kind!r}.')
def dtype_max_value(dtype):
    if dtype.kind == 'f':
        return jnp.inf
    elif dtype.kind == 'i':
        return jnp.iinfo(dtype).max
    elif dtype.kind == 'b':
        return True
    else:
        raise ValueError(f'Invalid data type {dtype.kind!r}.')
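# Usage sketch (illustrative; not from the original source): the two helpers
# above return identity elements for comparisons, so dtype_min_value is a safe
# initializer for a running maximum (and dtype_max_value for a running minimum).
import numpy as onp
import jax.numpy as jnp

running_max = jnp.full((3,), dtype_min_value(onp.dtype('int32')), jnp.int32)
x = jnp.array([4, -2, 7], dtype=jnp.int32)
running_max = jnp.maximum(running_max, x)  # [4, -2, 7]: every input wins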
def check_preprocessors(space, *preprocessors, num_samples=20, random_seed=None):
    r"""
    Check whether two or more preprocessors are the same.

    Parameters
    ----------
    space : gym.Space

        The domain of the preprocessors.

    \*preprocessors

        Preprocessor functions, which are functions with input signature:
        :code:`func(rng: PRNGKey, x: Element[space]) -> Any`.

    num_samples : positive int

        The number of samples on which to run checks.

    random_seed : int, optional

        Seed for the pseudo-random number generator.

    Returns
    -------
    match : bool

        Whether the preprocessors match.

    """
    if len(preprocessors) < 2:
        raise ValueError("need at least two preprocessors in order to run test")

    def test_leaves(a, b):
        assert type(a) is type(b)
        return onp.testing.assert_allclose(onp.asanyarray(a), onp.asanyarray(b))

    rngs = hk.PRNGSequence(
        onp.random.RandomState(random_seed).randint(jnp.iinfo('int32').max))
    p0, *ps = preprocessors

    with jax.disable_jit():
        for _ in range(num_samples):
            x = space.sample()
            y0 = p0(next(rngs), x)
            for p in ps:
                y = p(next(rngs), x)
                if jax.tree_structure(y) != jax.tree_structure(y0):
                    return False
                try:
                    jax.tree_multimap(test_leaves, y, y0)
                except AssertionError:
                    return False
    return True
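# Usage sketch (illustrative; not from the original source): compare two
# hypothetical one-hot preprocessors for a Discrete space. Both have the
# expected (rng, x) -> pytree signature and produce the same arrays, so the
# check should pass.
import gym
import jax.numpy as jnp

space = gym.spaces.Discrete(4)
p1 = lambda rng, x: jnp.eye(4)[x]
p2 = lambda rng, x: jnp.zeros(4).at[x].set(1.0)
assert check_preprocessors(space, p1, p2, num_samples=5)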
def sample_uint32(random_key: jax_types.PRNGKey) -> int:
  """Returns an integer uniformly distributed in 0..2^32-1."""
  iinfo = jnp.iinfo(jnp.int32)
  # randint only accepts int32 values as min and max.
  jax_random = jax.random.randint(
      random_key, shape=(), minval=iinfo.min, maxval=iinfo.max,
      dtype=jnp.int32)
  return np.uint32(jax_random).item()
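# Usage sketch (illustrative; not from the original source): derive a seed for
# a non-JAX random number generator from a JAX PRNG key, so both streams are
# reproducible from the same key.
import jax
import numpy as np

key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
rng = np.random.RandomState(sample_uint32(subkey))  # accepts seeds in 0..2^32-1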
def update_fn(updates, state, params=None):
    step_size = step_size_fn(state.count) * decay
    updates = jax.tree_multimap(lambda u, p: u - step_size * p, updates, params)
    # does a _safe_int32_increment
    max_int32_value = jnp.iinfo(jnp.int32).max
    new_count = jnp.where(state.count < max_int32_value, state.count + 1,
                          max_int32_value)
    new_state = DecoupledWeightDecayState(count=new_count, step_size=step_size)
    return updates, new_state
def _make_rotate_left(dtype):
  if not jnp.issubdtype(dtype, np.integer):
    raise TypeError("_rotate_left only accepts integer dtypes.")
  nbits = np.array(jnp.iinfo(dtype).bits, dtype)

  def _rotate_left(x, d):
    if lax.dtype(d) != dtype:
      d = lax.convert_element_type(d, dtype)
    if lax.dtype(x) != dtype:
      x = lax.convert_element_type(x, dtype)
    return lax.shift_left(x, d) | lax.shift_right_logical(x, nbits - d)

  return _rotate_left
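# Usage sketch (illustrative; not from the original source): rotate a uint32
# left by 13 bits, the kind of primitive used inside Threefry rounds.
import numpy as np
import jax.numpy as jnp

rotate_left = _make_rotate_left(np.uint32)
x = jnp.uint32(0x80000001)                      # bits 31 and 0 set
print(hex(int(rotate_left(x, np.uint32(13))))) # 0x3000: bits 12 and 13 set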
def safe_int32_increment(count: chex.Numeric) -> chex.Numeric:
  """Increments int32 counter by one.

  Normally `max_int + 1` would overflow to `min_int`. This function ensures
  that when `max_int` is reached the counter stays at `max_int`.

  Args:
    count: a counter to be incremented.

  Returns:
    A counter incremented by 1, or max_int if the maximum precision is
    reached.
  """
  chex.assert_type(count, jnp.int32)
  max_int32_value = jnp.iinfo(jnp.int32).max
  one = jnp.array(1, dtype=jnp.int32)
  return jnp.where(count < max_int32_value, count + one, max_int32_value)
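# Quick check (illustrative; not from the original source): a plain int32
# increment wraps around at the maximum, while safe_int32_increment saturates.
import jax.numpy as jnp

c = jnp.asarray(jnp.iinfo(jnp.int32).max, dtype=jnp.int32)
print(c + 1)                    # wraps to -2147483648
print(safe_int32_increment(c))  # stays at 2147483647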
def example_data(cls, env, observation_preprocessor=None,
                 action_preprocessor=None, batch_size=1, random_seed=None):
    if not isinstance(env.observation_space, Space):
        raise TypeError(
            "env.observation_space must be derived from gym.Space, "
            f"got: {type(env.observation_space)}")
    if not isinstance(env.action_space, Space):
        raise TypeError(
            "env.action_space must be derived from gym.Space, "
            f"got: {type(env.action_space)}")

    if observation_preprocessor is None:
        observation_preprocessor = ProbaDist(env.observation_space).preprocess_variate
    if action_preprocessor is None:
        action_preprocessor = default_preprocessor(env.action_space)

    rnd = onp.random.RandomState(random_seed)
    rngs = hk.PRNGSequence(rnd.randint(jnp.iinfo('int32').max))

    # input: state observations
    S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
    S = [observation_preprocessor(next(rngs), s) for s in S]
    S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)

    # input: actions
    A = [safe_sample(env.action_space, rnd) for _ in range(batch_size)]
    A = [action_preprocessor(next(rngs), a) for a in A]
    A = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *A)

    # output: type1
    S_next_type1 = jax.tree_map(
        lambda x: jnp.asarray(rnd.randn(batch_size, *x.shape[1:])), S)
    q1_data = ExampleData(
        inputs=Inputs(args=ArgsType1(S=S, A=A, is_training=True),
                      static_argnums=(2,)),
        output=S_next_type1)

    if not isinstance(env.action_space, Discrete):
        return ModelTypes(type1=q1_data, type2=None)

    # output: type2 (if actions are discrete)
    S_next_type2 = jax.tree_map(
        lambda x: jnp.asarray(
            rnd.randn(batch_size, env.action_space.n, *x.shape[1:])), S)
    q2_data = ExampleData(
        inputs=Inputs(args=ArgsType2(S=S, is_training=True),
                      static_argnums=(1,)),
        output=S_next_type2)

    return ModelTypes(type1=q1_data, type2=q2_data)
def observe(state, seed=None):
    """Observes the classical state of a quantum system.

    Collapses the quantum state into a classical state, sampling from the
    distribution of classical states given by the amplitudes in the quantum
    system described by `state`.

    :param state: Vector representation of quantum state to observe.
    :param seed: Optional seed with which to sample randomly.
    :return: Integer in [0, state.shape[0])
    """
    seed = seed if seed is not None else random.randint(
        0, np.iinfo(np.int32).max)
    key = jax.random.PRNGKey(seed)
    p = np.real(np.conj(state) * state)
    return jax.random.categorical(key, np.log(p))
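# Usage sketch (illustrative; not from the original source): observing an
# equal superposition of the two basis states returns 0 or 1, each with
# probability 1/2.
import jax.numpy as np  # matching the snippet's alias

plus = np.array([1.0, 1.0], dtype=np.complex64) / np.sqrt(2.0)
print(observe(plus, seed=0))  # 0 or 1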
def test_segment_max_negatives(self, indices_are_sorted, unique_indices):
    neg_inf = jnp.iinfo(jnp.int32).min
    if unique_indices:
        data = -1 - jnp.arange(6)  # [-1, -2, -3, -4, -5, -6]
        if indices_are_sorted:
            segment_ids = jnp.array([0, 1, 2, 3, 4, 5])
            expected_out = jnp.array([-1, -2, -3, -4, -5, -6])
            num_segments = 6
        else:
            segment_ids = jnp.array([1, 0, 2, 4, 3, -5])
            expected_out = jnp.array([-2, -1, -3, -5, -4])
            num_segments = 5
    else:
        data = -1 - jnp.arange(9)  # [-1, -2, -3, -4, -5, -6, -7, -8, -9]
        if indices_are_sorted:
            segment_ids = jnp.array([0, 0, 0, 1, 1, 1, 2, 3, 4])
            expected_out = jnp.array([-1, -4, -7, -8, -9, neg_inf])
        else:
            segment_ids = jnp.array([0, 1, 2, 0, 4, 0, 1, 1, -6])
            expected_out = jnp.array([-1, -2, -3, neg_inf, -5, neg_inf])
        num_segments = 6

    with self.subTest('nojit'):
        result = utils.segment_max(data, segment_ids, num_segments,
                                   indices_are_sorted, unique_indices)
        self.assertAllClose(result, expected_out, check_dtypes=True)

        result = utils.segment_max(data, segment_ids,
                                   indices_are_sorted=indices_are_sorted,
                                   unique_indices=unique_indices)
        num_unique_segments = jnp.maximum(
            jnp.max(segment_ids) + 1, jnp.max(-segment_ids))
        self.assertAllClose(result, expected_out[:num_unique_segments],
                            check_dtypes=True)

    with self.subTest('jit'):
        result = jax.jit(utils.segment_max, static_argnums=(2, 3, 4))(
            data, segment_ids, num_segments, indices_are_sorted,
            unique_indices)
        self.assertAllClose(result, expected_out, check_dtypes=True)
def main(unused_argv):
    using_SGD = FLAGS.using_SGD
    train_size = FLAGS.train_size
    x_train, y_train, x_test, y_test = pickle.load(
        open("data_" + str(train_size) + ".p", "rb"))
    print("Got data")
    sys.stdout.flush()

    # Build the network
    init_fn, apply_fn, _ = stax.serial(
        stax.Dense(2048, 1., 0.05),
        # stax.Erf(),
        stax.Relu(),
        stax.Dense(1, 1., 0.05))
    # NOTE: only MSE loss and 0/1 labels are implemented for now.

    # Initialize the network a first time, to compute the NTK.
    # (numpy.random.random_integers was removed from NumPy; randint is the
    # drop-in replacement, with an exclusive upper bound.)
    randnnn = numpy.random.randint(np.iinfo(np.int32).min,
                                   high=np.iinfo(np.int32).max,
                                   size=2)[0]
    key = random.PRNGKey(randnnn)
    _, params = init_fn(key, (-1, 784))

    # Create an MSE predictor to solve the NTK equation in function space.
    # We assume the NTK is approximately the same for any sample of
    # parameters (true in the limit of infinite width).
    sys.stdout.flush()
    g_dd = pickle.load(open("ntk_train_" + str(FLAGS.train_size) + ".p", "rb"))
    g_td = pickle.load(
        open("ntk_train_test_" + str(FLAGS.train_size) + ".p", "rb"))
    print("Got NTK")
    if not using_SGD:
        predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    batch_size = FLAGS.batch_size

    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    print(rank)

    for i in range(FLAGS.num_samples):
        if i % (ceil(FLAGS.num_samples / 100)) == 0:
            print(i)
            sys.stdout.flush()

        # Reinitialize the network.
        randnnn = numpy.random.randint(np.iinfo(np.int32).min,
                                       high=np.iinfo(np.int32).max,
                                       size=2)[0]
        key = random.PRNGKey(randnnn)
        _, params = init_fn(key, (-1, 784))

        # Get initial values of the network in function space.
        fx_train = apply_fn(params, x_train)
        fx_test = apply_fn(params, x_test)

        if using_SGD:
            error = 1
            lr = nt.predict.max_learning_rate(g_dd)
            print(lr)
            ntk_train = g_dd.squeeze()
            ntk_train_test = g_td.squeeze()
            if batch_size == train_size:
                indices = np.array(list(range(train_size)))
            while error >= 0.5:
                if batch_size != train_size:
                    indices = numpy.random.choice(range(train_size),
                                                  size=batch_size,
                                                  replace=False)
                # One (mini-batch) kernel gradient descent step on the MSE loss.
                fx_test = fx_test - lr * np.matmul(
                    ntk_train_test[:, indices],
                    (fx_train[indices] - y_train[indices])) / (2 * batch_size)
                fx_train = fx_train - lr * np.matmul(
                    ntk_train[:, indices],
                    (fx_train[indices] - y_train[indices])) / (2 * batch_size)
                error = np.dot(
                    (fx_train - y_train).squeeze(),
                    (fx_train - y_train).squeeze()) / (2 * train_size)
        else:
            # Get predictions from the analytic computation.
            fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

        OUTPUT = (fx_test > 0.5).astype(int)
        fun = ''.join([str(int(i)) for i in OUTPUT])
        TRUE_OUTPUT = (y_test > 0.5).astype(int)
        print("Generalization accuracy",
              np.sum(OUTPUT == TRUE_OUTPUT) / FLAGS.test_size)

        loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
        # util.print_summary('train', y_train, apply_fn(params, x_train), fx_train, loss)
        # util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)

        OUTPUT = (fx_train > 0.5).astype(int)
        TRUE_OUTPUT = (y_train > 0.5).astype(int)
        print("Training accuracy",
              np.sum(OUTPUT == TRUE_OUTPUT) / FLAGS.train_size)
        assert np.all(OUTPUT == TRUE_OUTPUT)

        file = open('results/data_{}_large.txt'.format(rank), 'a')
        file.write(fun + '\n')
        file.close()
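# Background note (not from the original source; normalization conventions
# vary across references): nt.predict.gradient_descent_mse evaluates the
# closed-form solution of gradient flow on the MSE loss for the linearized
# network,
#     f_train(t) = y + exp(-eta * Theta_dd * t) @ (f_train(0) - y),
# with test predictions obtained through the test/train kernel Theta_td.
# The using_SGD branch above approximates the same dynamics with discrete,
# mini-batched kernel gradient steps.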
import inspect
from functools import partial

import jax.numpy as jnp
from jax import jit

from onnx_jax.handlers.backend_handler import BackendHandler
from onnx_jax.handlers.handler import onnx_op
from onnx_jax.pb_wrapper import OnnxNode

int32_max = jnp.iinfo(jnp.int32).max


@onnx_op("Slice")
class Slice(BackendHandler):

    @classmethod
    def _common(cls, node: OnnxNode, **kwargs):
        cls._rewrite(node)
        cls._prepare(node)

        def _slice(x, starts, ends, axes=None, steps=None):
            if axes is not None:
                axes = tuple(axes)
            if steps is not None:
                steps = tuple(steps)
            # Clamp out-of-range end indices to int32 max.
            ends = [x if x < int32_max else int32_max for x in ends]
            return onnx_slice(x, tuple(starts), tuple(ends), axes, steps)

        return _slice

    @classmethod
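# Behavior sketch (illustrative; not from the original source): ONNX Slice
# encodes "to the end of the axis" as a huge sentinel end index (INT64_MAX);
# the clamp in _slice maps such sentinels to int32 max so the indices fit
# 32-bit arithmetic.
ends = [9223372036854775807, 3]
clamped = [e if e < int32_max else int32_max for e in ends]
# clamped == [2147483647, 3]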
def _safe_int32_increment(count):
    chex.assert_type(count, jnp.int32)
    max_int32_value = jnp.iinfo(jnp.int32).max
    one = jnp.array(1, dtype=jnp.int32)
    return jnp.where(count < max_int32_value, count + one, max_int32_value)
def main(unused_argv):
    loss = FLAGS.loss
    train_size = FLAGS.train_size
    x_train, y_train, x_test, y_test = pickle.load(
        open("data_" + str(train_size) + ".p", "rb"))
    print("Got data")
    sys.stdout.flush()

    # Build the network
    init_fn, apply_fn, _ = stax.serial(stax.Dense(512, 1., 0.0357),
                                       stax.Relu(),
                                       stax.Dense(512, 1., 0.0357),
                                       stax.Relu(),
                                       stax.Dense(1, 1., 0.0357))

    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    # opt_init, opt_apply, get_params = optimizers.adam(0.001)

    # Create a loss function and a gradient function.
    if loss == "mse":
        loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
        decision_threshold = 0.5
    elif loss == "ce":
        # Numerically stable binary cross-entropy with logits:
        # max(fx, 0) - fx*y + log(1 + exp(-|fx|)), summed over the batch.
        # (np.maximum is the elementwise max; np.max would reduce over axis 0.)
        loss = lambda fx, y: np.sum(
            np.maximum(fx, 0) - fx * y + np.log(1 + np.exp(-np.abs(fx))))
        decision_threshold = 0.0
    else:
        raise NotImplementedError()

    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    batch_size = FLAGS.batch_size

    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    print(rank)

    for i in range(FLAGS.num_samples):
        # Reinitialize the network.
        randnnn = numpy.random.randint(np.iinfo(np.int32).min,
                                       high=np.iinfo(np.int32).max,
                                       size=2)[0]
        key = random.PRNGKey(randnnn)
        _, params = init_fn(key, (-1, 784))
        state = opt_init(params)

        # Get initial values of the network in function space.
        fx_train = apply_fn(params, x_train)

        OUTPUT = (fx_train > 0).astype(int)
        TRUE_OUTPUT = (y_train > 0).astype(int)
        train_acc = np.sum(OUTPUT == TRUE_OUTPUT) / FLAGS.train_size

        # Train until the network fits the training set exactly.
        while train_acc < 1.0:
            if batch_size != train_size:
                indices = numpy.random.choice(range(train_size),
                                              size=batch_size,
                                              replace=False)
            else:
                indices = np.array(list(range(train_size)))
            x_batch = x_train[indices]
            y_batch = y_train[indices]
            state = opt_apply(i, grad_loss(params, x_batch, y_batch), state)
            params = get_params(state)

            fx_train_batch = apply_fn(params, x_batch)
            OUTPUT_batch = (fx_train_batch > decision_threshold).astype(int)
            TRUE_OUTPUT_batch = (y_batch > decision_threshold).astype(int)
            train_acc_batch = np.sum(
                OUTPUT_batch == TRUE_OUTPUT_batch) / FLAGS.batch_size
            print(train_acc_batch)
            if train_acc_batch == 1.0:
                fx_train = apply_fn(params, x_train)
                OUTPUT = (fx_train > decision_threshold).astype(int)
                TRUE_OUTPUT = (y_train > decision_threshold).astype(int)
                train_acc = np.sum(OUTPUT == TRUE_OUTPUT) / FLAGS.train_size
                print("train_acc", train_acc)

        fx_train = apply_fn(params, x_train)
        fx_test = apply_fn(params, x_test)

        OUTPUT = (fx_train > decision_threshold).astype(int)
        TRUE_OUTPUT = (y_train > decision_threshold).astype(int)
        train_acc = np.sum(OUTPUT == TRUE_OUTPUT) / FLAGS.train_size
        print("Training accuracy", train_acc)
        assert train_acc == 1.0

        OUTPUT = (fx_test > decision_threshold).astype(int)
        fun = ''.join([str(int(i)) for i in OUTPUT])
        TRUE_OUTPUT = (y_test > decision_threshold).astype(int)
        test_acc = np.sum(OUTPUT == TRUE_OUTPUT) / FLAGS.test_size
        print("Generalization accuracy", test_acc)

        file = open('data_{}_large.txt'.format(rank), 'a')
        file.write(fun + '\n')
        file.close()
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    import numpy

    # Collapse the 10 MNIST classes to binary (even/odd) labels.
    y_train_tmp = numpy.argmax(y_train, 1) % 2
    y_train = np.expand_dims(y_train_tmp, 1)
    y_test_tmp = numpy.argmax(y_test, 1) % 2
    y_test = np.expand_dims(y_test_tmp, 1)

    # Build the network
    init_fn, apply_fn, _ = stax.serial(stax.Dense(2048, 1., 0.05),
                                       stax.Erf(),
                                       stax.Dense(1, 1., 0.05))

    # key = random.PRNGKey(0)
    randnnn = numpy.random.randint(np.iinfo(np.int32).min,
                                   high=np.iinfo(np.int32).max,
                                   size=2)[0]
    key = random.PRNGKey(randnnn)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = nt.predict.gradient_descent_mse(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = apply_fn(params, x_train)
    fx_test = apply_fn(params, x_test)

    # Train the network.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, apply_fn(params, x_train), fx_train,
                       loss)
    util.print_summary('test', y_test, apply_fn(params, x_test), fx_test, loss)
def _safesub(x, y):
    try:
        finfo = np.finfo(y.dtype)
    except ValueError:
        finfo = np.iinfo(y.dtype)
    return x + np.clip(-y, a_min=None, a_max=finfo.max)
def _safediv(x, y):
    try:
        finfo = np.finfo(y.dtype)
    except ValueError:
        finfo = np.iinfo(y.dtype)
    return x * np.clip(np.reciprocal(y), a_min=None, a_max=finfo.max)
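# Usage sketch (illustrative; not from the original source): _safediv clips
# the reciprocal at the dtype's max, so dividing by zero yields a large finite
# number instead of inf; _safesub clips -y the same way.
import numpy as np

x = np.float32(1.0)
y = np.float32(0.0)
print(_safediv(x, y))  # ~3.4e38 (float32 max) rather than inf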