def testFunctionInterop(self):
    """Checks that an add function agrees when run directly and via `tf.function`."""

    def add(a, b):
        return a + b

    lhs = np.asarray(3.0)
    rhs = np.asarray(2.0)
    traced_add = tf.function(add)

    direct_out = add(lhs, rhs)
    traced_out = traced_add(lhs, rhs)

    # Both call paths are expected to hand back ndarrays with equal values.
    for out in (direct_out, traced_out):
        self.assertIsInstance(out, np.ndarray)
    self.assertAllClose(direct_out, traced_out)
def testJacobian(self):
    """Checks `GradientTape.jacobian` of `x * x * y` with respect to `x` and `y`."""
    x = np.asarray([1., 2.])
    y = np.asarray([3., 4.])
    with tf.GradientTape() as tape:
        tape.watch(x)
        tape.watch(y)
        z = x * x * y

    jacobians = tape.jacobian(z, [x, y])
    # z is elementwise, so each expected jacobian is a diagonal matrix.
    expected = [tf.linalg.diag(2 * x * y), tf.linalg.diag(x * x)]

    for idx in (0, 1):
        self.assertIsInstance(jacobians[idx], np.ndarray)
    self.assertAllClose(jacobians, expected)
def normal(key, shape, dtype=tf.float32):
    """Draws samples from the standard normal distribution.

    Args:
      key: the RNG key.
      shape: the shape of the result.
      dtype: the dtype of the result.

    Returns:
      Standard-normal random values of the requested shape and dtype.
    """
    rng_key = tf_np.asarray(key, dtype=_RNG_KEY_DTYPE)
    seed = _key2seed(rng_key)
    samples = tf.random.stateless_normal(shape, seed=seed, dtype=dtype)
    return tf_np.asarray(samples)
def ntk_fn(x1: np.ndarray,
           x2: Optional[np.ndarray],
           params: PyTree,
           **apply_fn_kwargs) -> np.ndarray:
    """Computes one sample of the empirical NTK (jacobian outer product).

    Args:
      x1: first batch of inputs.
      x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have
        a matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params: a `PyTree` of parameters with respect to which the neural
        tangent kernel is computed.
      **apply_fn_kwargs: keyword arguments passed to `apply_fn`. They are
        split into `apply_fn_kwargs1` and `apply_fn_kwargs2` by
        `_split_kwargs` before being passed on. In particular, the rng key
        in `apply_fn_kwargs` is split into two different (if `x1!=x2`) or
        same (if `x1==x2`) rng keys. See the `_read_key` function for more
        details.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is
      "almost" `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """
    kwargs1, kwargs2 = _split_kwargs(apply_fn_kwargs, x1, x2)

    def _jacobian(fn):
        # Jacobian of `fn(params)` with respect to `params`.
        with tf.GradientTape() as tape:
            tape.watch(params)
            out = fn(params)
        return np.asarray(tape.jacobian(out, params))

    f1 = _get_f_params(f, x1, **kwargs1)
    j1 = _jacobian(f1)

    if x2 is None:
        # Same inputs on both sides: reuse the first jacobian.
        j2 = j1
    else:
        j2 = _jacobian(_get_f_params(f, x2, **kwargs2))

    fx1 = eval_on_shapes(f1)(params)
    kernel = sum_and_contract(j1, j2, fx1.ndim)
    return kernel / utils.size_at(fx1, trace_axes)
def testBatchJacobian(self):
    """Checks `GradientTape.batch_jacobian` of `x * x * y` with respect to `x`."""
    x = np.asarray([[1., 2.], [3., 4.]])
    y = np.asarray([[3., 4.], [5., 6.]])
    with tf.GradientTape() as tape:
        tape.watch(x)
        tape.watch(y)
        z = x * x * y

    batch_jacobian = tape.batch_jacobian(z, x)
    # One diagonal jacobian per row of the two-row batch.
    expected = tf.stack(
        [tf.linalg.diag(2 * x[row] * y[row]) for row in (0, 1)])

    self.assertIsInstance(batch_jacobian, np.ndarray)
    self.assertAllClose(batch_jacobian, expected)
def testPyFuncInterop(self):
    """Checks that a `tf.py_function` result can be wrapped and returned as an ndarray."""

    def py_func_fn(a, b):
        return a + b

    def wrapped(a, b):
        return np.asarray(tf.py_function(py_func_fn, [a, b], a.dtype))

    fn = tf.function(wrapped)

    one = np.asarray(1.)
    two = np.asarray(2.)
    total = fn(one, two)

    self.assertIsInstance(total, np.ndarray)
    self.assertAllClose(total, 3.)
def testAxes(self, diagonal_axes, trace_axes):
    """Compares the kernel functions from `KERNELS['empirical_logits_3']`.

    The `implicit` and `direct` kernels must agree in value, and the `nngp`
    kernel must have the same shape as `implicit` and the expected rank.
    The test is skipped when a diagonal axis coincides with a trace axis.
    """
    key = stateless_uniform(shape=[2],
                            seed=[0, 0],
                            minval=None,
                            maxval=None,
                            dtype=tf.int32)
    splits = tf_random_split(seed=tf.convert_to_tensor(key, dtype=tf.int32),
                             num=3)
    key, self_split, other_split = (splits[i] for i in range(3))

    data_self = np.asarray(normal((4, 5, 6, 3), seed=self_split))
    data_other = np.asarray(normal((2, 5, 6, 3), seed=other_split))

    canonical_diagonal = utils.canonicalize_axis(diagonal_axes, data_self)
    canonical_trace = utils.canonicalize_axis(trace_axes, data_self)

    overlapping = any(
        d == c for d in canonical_diagonal for c in canonical_trace)
    if overlapping:
        raise absltest.SkipTest(
            'diagonal axes must be different from channel axes.')

    implicit, direct, nngp = KERNELS['empirical_logits_3'](
        key, (5, 6, 3), CONV,
        diagonal_axes=diagonal_axes,
        trace_axes=trace_axes)

    n_marg = len(canonical_diagonal)
    n_chan = len(canonical_trace)

    def check(first, second):
        # Same three assertions for every pair of inputs.
        g = implicit(first, second)
        g_direct = direct(first, second)
        g_nngp = nngp(first, second)
        self.assertAllClose(g, g_direct)
        self.assertEqual(g_nngp.shape, g.shape)
        self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim)

    check(data_self, None)

    # The two-input case only runs when axis 0 is neither traced nor diagonal.
    if 0 not in canonical_trace and 0 not in canonical_diagonal:
        check(data_other, data_self)
def testGradientTapeInterop(self):
    """Checks that `GradientTape.gradient` returns ndarray gradients."""
    x = np.asarray(3.0)
    y = np.asarray(2.0)
    with tf.GradientTape() as tape:
        tape.watch([x, y])
        doubled_x = 2 * x
        tripled_y = 3 * y

    dx, dy = tape.gradient([doubled_x, tripled_y], [x, y])

    for grad in (dx, dy):
        self.assertIsInstance(grad, np.ndarray)
    self.assertAllClose(dx, 2.0)
    self.assertAllClose(dy, 3.0)
def test_tf_conv_general_dilated(self, lhs_shape, rhs_shape, strides, padding,
                                 lhs_dilation, rhs_dilation,
                                 feature_group_count, batch_group_count,
                                 dimension_numbers, perms):
    """Checks the TF `conv_general_dilated` against `jax.lax.conv_general_dilated`."""
    message = "dimension_numbers: {}".format(dimension_numbers)
    tf.print(message, output_stream=sys.stdout)

    lhs_perm, rhs_perm = perms

    # All-ones operands, permuted to compatible shapes, for both backends.
    lhs_tf = tfnp.transpose(tfnp.ones(lhs_shape), lhs_perm)
    rhs_tf = tfnp.transpose(tfnp.ones(rhs_shape), rhs_perm)
    lhs_jax = jnp.transpose(jnp.ones(lhs_shape), lhs_perm)
    rhs_jax = jnp.transpose(jnp.ones(rhs_shape), rhs_perm)

    expected = jax.lax.conv_general_dilated(
        lhs_jax, rhs_jax, strides, padding, lhs_dilation, rhs_dilation,
        dimension_numbers, feature_group_count, batch_group_count)

    # The JAX output shape is passed to the TF implementation as `output_shape`.
    actual = lax.conv_general_dilated(
        lhs_tf, rhs_tf, strides, padding, expected.shape, lhs_dilation,
        rhs_dilation, dimension_numbers, feature_group_count,
        batch_group_count)

    self.assertAllEqual(actual, tfnp.asarray(expected))
def testTFNPArrayTFOpInterop(self):
    """Checks that a TF op applied to an ndarray yields a `tf.Tensor`."""
    ten = np.asarray(10.)

    # TODO(nareshmodi): Test more ops.
    squared = tf.square(ten)

    self.assertIsInstance(squared, tf.Tensor)
    self.assertEqual(100., squared.numpy())
def testDistStratInterop(self):
    """Checks ndarray values produced under a three-device `MirroredStrategy`."""
    strategy = tf.distribute.MirroredStrategy(
        devices=['CPU:0', 'CPU:1', 'CPU:2'])
    multiplier = np.asarray(5.)

    @tf.function
    def run():
        replica_ctx = tf.distribute.get_replica_context()
        replica_id = np.asarray(replica_ctx.replica_id_in_sync_group)
        return replica_id * multiplier

    distributed_values = strategy.run(run)
    reduced = strategy.reduce(tf.distribute.ReduceOp.SUM,
                              distributed_values,
                              axis=None)
    values = strategy.experimental_local_results(distributed_values)

    # Note that this should match the number of virtual CPUs.
    self.assertLen(values, 3)
    for idx in range(3):
        self.assertIsInstance(values[idx], np.ndarray)
    # Each replica is expected to contribute replica_id * 5.
    for idx, expected in enumerate((0, 5, 10)):
        self.assertAllClose(values[idx], expected)

    # "strategy.reduce" doesn't rewrap in ndarray.
    # self.assertIsInstance(reduced, np.ndarray)
    self.assertAllClose(reduced, 15)
def testTFNPArrayNPOpInterop(self):
    """Checks that `onp.square` applied to an ndarray yields an `onp.ndarray`."""
    ten = np.asarray([10.])

    # TODO(nareshmodi): Test more ops.
    squared = onp.square(ten)

    self.assertIsInstance(squared, onp.ndarray)
    self.assertEqual(100., squared[0])
def testDatasetInterop(self):
    """Checks that `tf.data` datasets built from an ndarray yield ndarrays."""
    values = [1, 2, 3, 4, 5, 6]
    values_as_array = np.asarray(values)

    def check(expected_items, dataset):
        # Every element must be an ndarray equal to its expected item.
        for expected, actual in zip(expected_items, dataset):
            self.assertIsInstance(actual, np.ndarray)
            self.assertAllEqual(actual, expected)

    # Tensor dataset
    check([values_as_array], tf.data.Dataset.from_tensors(values_as_array))

    # Tensor slice dataset
    sliced = tf.data.Dataset.from_tensor_slices(values_as_array)
    check(values, sliced)

    # # TODO(nareshmodi): as_numpy_iterator() doesn't work.
    # items = list(dataset.as_numpy_iterator())

    # Map over a dataset.
    incremented = sliced.map(lambda x: np.add(x, 1))
    check([value + 1 for value in values], incremented)

    # Batch a dataset.
    batched = tf.data.Dataset.from_tensor_slices(values_as_array).batch(2)
    check([[1, 2], [3, 4], [5, 6]], batched)
def testIndex(self):
    """Checks that indexing a Python list with an ndarray raises TypeError in `tf.function`."""

    @tf.function
    def pick(x):
        choices = [0, 1]
        return choices[x]

    index = np.asarray([1])
    with self.assertRaises(TypeError):
        pick(index)
def logsumexp(x, axis=None, keepdims=None):
    """Computes log(sum(exp(elements across dimensions of a tensor))).

    Reduces `x` along the dimensions given in `axis`. Unless `keepdims` is
    true, the rank of the tensor is reduced by 1 for each entry in `axis`. If
    `keepdims` is true, the reduced dimensions are retained with length 1.
    If `axis` has no entries, all dimensions are reduced, and a tensor with a
    single element is returned.

    This function is more numerically stable than log(sum(exp(input))). It
    avoids overflows caused by taking the exp of large inputs and underflows
    caused by taking the log of small inputs.

    Args:
      x: The array to reduce. Should have numeric type. It is passed through
        `tf_np.asarray` first, the same normalisation `sort_key_val` applies
        to its inputs.
      axis: The dimensions to reduce. If `None` (the default), reduces all
        dimensions. Must be in the range `[-rank(x), rank(x))`.
      keepdims: If true, retains reduced dimensions with length 1.

    Returns:
      The reduced tensor.
    """
    # Normalise first so that `.data` is read from the `tf_np.asarray` result
    # rather than from whatever object the caller passed in.
    x = tf_np.asarray(x)
    reduced = tf.math.reduce_logsumexp(
        input_tensor=x.data, axis=axis, keepdims=keepdims)
    return tf_np.asarray(reduced)
def sort_key_val(keys, values, dimension=-1):
    """Sorts `keys` along one dimension and permutes `values` the same way.

    Args:
      keys: an array. The dtype must be comparable numbers (integers and
        reals).
      values: an array, with the same shape of `keys`.
      dimension: an `int`. The dimension along which to sort.

    Returns:
      Permuted keys and values.

    Raises:
      ValueError: if the rank of both `keys` and `values` is unknown.
    """
    keys = tf_np.asarray(keys)
    values = tf_np.asarray(values)

    rank = keys.data.shape.ndims
    if rank is None:
        rank = values.data.shape.ndims
    if rank is None:
        # We need to know the rank because tf.gather requires batch_dims to
        # be `int`.
        raise ValueError(
            "The rank of either keys or values must be known, but "
            "both are unknown (i.e. their shapes are both None).")

    already_last = dimension in (-1, rank - 1)

    def _swap_with_last(a):
        # We need to swap axes because tf.gather (and tf.gather_nd) supports
        # batch_dims on the left but not on the right.
        # TODO(wangpeng): Investigate whether we should do swapaxes or
        #   moveaxis.
        if already_last:
            return a
        return tf_np.swapaxes(a, dimension, -1)

    keys = _swap_with_last(keys)
    values = _swap_with_last(values)

    order = tf_np.argsort(keys).data

    def _take_in_order(a):
        # Using tf.gather rather than np.take because the former supports
        # batch_dims.
        return tf_np.asarray(tf.gather(a.data, order, batch_dims=rank - 1))

    keys = _take_in_order(keys)
    values = _take_in_order(values)

    # Swapping the same two axes again restores the original layout.
    keys = _swap_with_last(keys)
    values = _swap_with_last(values)
    return keys, values
def testMapFn(self):
    """Checks that `tf.map_fn` over a pair of ndarrays returns ndarrays."""
    x = np.asarray([1., 2.])

    def add_one_to_each(pair):
        return (pair[0] + 1, pair[1] + 1)

    mapped = tf.map_fn(add_one_to_each, (x, x))

    for idx in (0, 1):
        self.assertIsInstance(mapped[idx], np.ndarray)
    for idx in (0, 1):
        self.assertAllClose(mapped[idx], [2., 3.])
def testGradientDescentMseEnsembleTrain(self):
    """Checks that the predictor gives matching outputs for `None` and the training inputs."""
    key = stateless_uniform(shape=[2],
                            seed=[1, 1],
                            minval=None,
                            maxval=None,
                            dtype=tf.int32)
    x = np.asarray(normal((8, 4, 6, 3), seed=key))

    layers = (stax.Conv(1, (2, 2)), stax.Relu(), stax.Conv(1, (2, 1)))
    _, _, kernel_fn = stax.serial(*layers)

    y = np.asarray(normal((8, 2, 5, 1), seed=key))
    predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x, y)

    times = [None, np.array([0., 1., 10.])]
    for t in times:
        with self.subTest(t=t):
            out_for_none = predictor(t, None, None, compute_cov=True)
            out_for_train = predictor(t, x, None, compute_cov=True)
            self._assertAllClose(out_for_none, out_for_train, 0.04)
def testGradientTapeInterop(self):
    """Checks the gradient values `GradientTape` computes for ndarray inputs."""
    x = np.asarray(3.0)
    y = np.asarray(2.0)
    with tf.GradientTape() as tape:
        tape.watch([x, y])
        doubled_x = 2 * x
        tripled_y = 3 * y

    dx, dy = tape.gradient([doubled_x, tripled_y], [x, y])

    # # TODO(nareshmodi): Figure out a way to rewrap ndarray as tensors.
    # self.assertIsInstance(dx, np.ndarray)
    # self.assertIsInstance(dy, np.ndarray)
    self.assertAllClose(dx, 2.0)
    self.assertAllClose(dy, 3.0)
def _f(*args):
    """Runs `f` through `tf.vectorized_map` and wraps the outputs with `tf_np.asarray`."""

    def _to_tensor(arg):
        return tf_np.asarray(arg).data

    tensor_args = tf.nest.map_structure(_to_tensor, args)
    # `tf.vectorized_map` hands over one structure; unpack it into `f`'s
    # positional arguments.
    mapped = tf.vectorized_map(lambda packed: f(*packed), tensor_args)
    return tf.nest.map_structure(tf_np.asarray, mapped)
def testIter(self):
    """Checks that unpacking an ndarray inside `tf.function` raises TypeError."""

    @tf.function
    def unpack(x):
        first, second = x
        return first, second

    pair = np.asarray([3, 4])
    with self.assertRaises(TypeError):
        unpack(pair)
def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape,
                         lhs_dilation=None, rhs_dilation=None,
                         dimension_numbers=None, feature_group_count=1,
                         batch_group_count=1, precision=None):
    """General convolution: plain, transposed (deconvolution) or dilated.

    Dispatches to the 1D/2D/3D `nn.conv*` or `nn.conv*_transpose` op after
    rearranging the operands to a batch-first, channels-last layout.

    Args:
      lhs: the input array.
      rhs: the filter array.
      window_strides: strides; converted together with the dilations by
        `_conv_general_param_type_converter`.
      padding: either `'SAME'` or `'VALID'`.
      output_shape: output shape, used only on the transposed-convolution
        path.
      lhs_dilation: input dilation; selects the transposed-convolution path.
      rhs_dilation: filter dilation; selects the regular-convolution path.
      dimension_numbers: a `(lhs_spec, rhs_spec, out_spec)` triple. String
        specs are searched for `N`/`C` (lhs) and `O`/`I` (rhs); otherwise
        the first two entries of each spec are read as (N, C) and (O, I)
        axis positions. Required despite the `None` default.
      feature_group_count: accepted but not used by this implementation.
      batch_group_count: accepted but not used by this implementation.
      precision: accepted but not used by this implementation.

    Returns:
      The convolution result, with the batch and channel axes moved back to
      the positions given by `lhs_spec`.

    Raises:
      TypeError: if `dimension_numbers` is `None`, if `lhs_spec` and
        `out_spec` differ, if there are 4 or more spatial dimensions, if
        both dilations are given and are not all ones, or if `padding` is
        neither `'SAME'` nor `'VALID'`.
    """
    if dimension_numbers is None:
        # Fail with an explicit message instead of an opaque unpacking error.
        raise TypeError('Current implementation requires `dimension_numbers` '
                        'to be a (lhs_spec, rhs_spec, out_spec) triple, but '
                        'got: None')
    lhs_spec, rhs_spec, out_spec = dimension_numbers
    if lhs_spec != out_spec:
        raise TypeError('Current implementation requires the `data_format` of the '
                        'inputs and outputs to be the same.')
    if len(lhs_spec) >= 6:
        raise TypeError('Current implementation does not support 4 or higher '
                        'dimensional convolution, but got: {}'.format(
                            len(lhs_spec) - 2))
    dim = len(lhs_spec) - 2

    if lhs_dilation and rhs_dilation:
        if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim:
            # All-ones dilations are no-ops; treat them as absent.
            lhs_dilation, rhs_dilation = None, None
        else:
            raise TypeError('Current implementation does not support that '
                            'deconvolution and dilation to be performed at the same '
                            'time, but got lhs_dilation: {}, rhs_dilation: {}'.format(
                                lhs_dilation, rhs_dilation))
    if padding not in ['SAME', 'VALID']:
        raise TypeError('Current implementation requires the padding parameter '
                        'to be either `VALID` or `SAME`, but got: {}'.format(
                            padding))

    # Convert params from int/Sequence[int] to list of ints.
    strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
        window_strides, lhs_dilation, rhs_dilation
    )

    # Preprocess the shapes: locate the batch/channel axes of both operands.
    dim_maps = {}
    if isinstance(lhs_spec, str):
        dim_maps['I'] = list(rhs_spec).index('I')
        dim_maps['O'] = list(rhs_spec).index('O')
        dim_maps['N'] = list(lhs_spec).index('N')
        dim_maps['C'] = list(lhs_spec).index('C')
    else:
        dim_maps['I'] = rhs_spec[1]
        dim_maps['O'] = rhs_spec[0]
        dim_maps['N'] = lhs_spec[0]
        dim_maps['C'] = lhs_spec[1]

    # Put the batch axis first and the channel axis last.
    lhs = np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1))
    # Adjust the filters, put the dimension 'I' and 'O' at last.
    rhs = np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim))

    spatial_dim_maps = {1: 'W', 2: 'HW', 3: 'DHW'}
    data_format = 'N' + spatial_dim_maps[dim] + 'C'
    # For each spatial rank: [regular convolution, transposed convolution].
    tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose],
                  2: [nn.conv2d, nn.conv2d_transpose],
                  3: [nn.conv3d, nn.conv3d_transpose]}

    if rhs_dilation or (lhs_dilation is None and rhs_dilation is None):
        # Regular (optionally filter-dilated) convolution.
        output = tf_nn_APIs[dim][0](lhs, rhs, strides, padding,
                                    data_format, rhs_dilation)
    else:
        # Only lhs_dilation is set: transposed convolution.
        output = tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape),
                                    strides, padding, data_format,
                                    lhs_dilation)

    # Undo the layout change applied to `lhs`.
    output = np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C']))
    return np.asarray(output)
def testSerial(self, train_shape, test_shape, network, name, kernel_fn,
               batch_size):
    """Compares a kernel function with its `batch._serial` batched version."""
    seed_key = stateless_uniform(shape=[2],
                                 seed=[0, 0],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
    keys = tf_random_split(seed_key, 3)
    key, self_split, other_split = (keys[i] for i in range(3))

    data_self = np.asarray(normal(train_shape, seed=self_split))
    data_other = np.asarray(normal(test_shape, seed=other_split))

    # Input shape without the leading axis of `train_shape`.
    kernel_fn = kernel_fn(key, train_shape[1:], network)
    kernel_batched = batch._serial(kernel_fn, batch_size=batch_size)

    _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                                 data_other)
def testLen(self):
    """Checks that `len` on a data-dependent shape raises TypeError in `tf.function`."""

    @tf.function
    def count_true(x):
        # Note that shape of input to len is data dependent.
        true_indices = np.where(x)[0]
        return len(true_indices)

    flags = np.asarray([True, False, True])
    with self.assertRaises(TypeError):
        count_true(flags)
def testLinearization(self, shape):
    """Compares `empirical.linearize` with `EmpiricalTest.f_lin_exact` on random inputs."""

    def split_key(current_key, num):
        return tf_random_split(
            seed=tf.convert_to_tensor(current_key, dtype=tf.int32), num=num)

    key = stateless_uniform(shape=[2],
                            seed=[0, 0],
                            minval=None,
                            maxval=None,
                            dtype=tf.int32)
    splits = split_key(key, 4)
    key, s1, s2, s3 = (splits[i] for i in range(4))

    w1 = np.asarray(normal(shape, seed=s1))
    # Average w1 with its transpose.
    w1 = 0.5 * (w1 + w1.T)
    w2 = np.asarray(normal(shape, seed=s2))
    b = np.asarray(normal((shape[-1], ), seed=s3))
    params = (w1, w2, b)

    splits = split_key(key, 2)
    key, split = splits[0], splits[1]
    x0 = np.asarray(normal((shape[-1], ), seed=split))

    f_lin = empirical.linearize(EmpiricalTest.f, x0)

    flag_pairs = [(do_alter, do_shift_x)
                  for do_alter in (True, False)
                  for do_shift_x in (True, False)]
    for _ in range(TAYLOR_RANDOM_SAMPLES):
        for do_alter, do_shift_x in flag_pairs:
            # Fresh evaluation point for every flag combination.
            splits = split_key(key, 2)
            key, split = splits[0], splits[1]
            x = np.asarray(normal((shape[-1], ), seed=split))

            expected = EmpiricalTest.f_lin_exact(
                x0, x, params, do_alter, do_shift_x=do_shift_x)
            actual = f_lin(x, params, do_alter, do_shift_x=do_shift_x)
            self.assertAllClose(expected, actual)
def train(batch_size, learning_rate, num_training_iters, validation_steps):
    """Trains a `Model([128, 32])` on MNIST and periodically prints accuracy.

    Args:
      batch_size: batch size used for both the train and the test dataset.
      learning_rate: learning rate passed to `Model`.
      num_training_iters: number of passes over the training dataset.
      validation_steps: loss and accuracies are printed every
        `validation_steps` passes.
    """

    def _flatten_images(images):
        # Scale to [0, 1] and flatten every image into a 784-vector.
        scaled = images / 255.0
        return np.asarray(tf.reshape(scaled, (-1, 784))).astype(np.float32)

    def _one_hot_labels(labels):
        return np.asarray(tf.one_hot(labels, 10)).astype(np.float32)

    def _batched(features, labels):
        dataset = tf.data.Dataset.from_tensor_slices((features, labels))
        return dataset.batch(batch_size)

    # Loading the MNIST Dataset
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = _flatten_images(x_train)
    x_test = _flatten_images(x_test)
    y_train = _one_hot_labels(y_train)
    y_test = _one_hot_labels(y_test)

    mnist_train = _batched(x_train, y_train)
    mnist_test = _batched(x_test, y_test)
    print('Initialized MNIST with {} training and {} test examples.'.format(
        x_train.shape[0], x_test.shape[0]))

    model = Model([128, 32], learning_rate=learning_rate)

    def _count_correct(dataset):
        # Sum of `model.evaluate` over every batch of the dataset.
        return sum(model.evaluate(batch_x, batch_y)
                   for batch_x, batch_y in dataset)

    # The training loop
    loss = 0
    for i in range(num_training_iters):
        for x, y in mnist_train:
            loss += model.train(x, y)

        if (i + 1) % validation_steps != 0:
            continue

        # Calculate and print the train and test accuracy
        correct_train_predictions = _count_correct(mnist_train)
        correct_test_predictions = _count_correct(mnist_test)
        print('[{}] Loss: {}, train acc: {}, test acc: {}'.format(
            i + 1,
            round(float(loss.data / validation_steps), 4),
            round(correct_train_predictions / x_train.shape[0], 4),
            round(correct_test_predictions / x_test.shape[0], 4)))
        # Start a fresh accumulation window after each report.
        loss = 0
def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
    """Compares a kernel function with its `batch._parallel` version."""
    test_utils.stub_out_pmap(batch, 2)

    seed_key = stateless_uniform(shape=[2],
                                 seed=[0, 0],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
    keys = tf_random_split(seed_key, 3)
    key, self_split, other_split = (keys[i] for i in range(3))

    data_self = np.asarray(normal(train_shape, seed=self_split))
    data_other = np.asarray(normal(test_shape, seed=other_split))

    # Input shape without the leading axis of `train_shape`; dropout is off.
    kernel_fn = kernel_fn(key, train_shape[1:], network, use_dropout=False)
    kernel_batched = batch._parallel(kernel_fn)

    _test_kernel_against_batched(self, kernel_fn, kernel_batched, data_self,
                                 data_other, True)
def testTensorTFNPArrayInterop(self):
    """Checks that adding an ndarray and a `tf.Tensor` gives a `tf.Tensor` in either order."""
    arr = np.asarray(0.)
    tensor = tf.constant(10.)

    sums = (arr + tensor, tensor + arr)

    for total in sums:
        self.assertIsInstance(total, tf.Tensor)
    for total in sums:
        self.assertEqual(10., total.numpy())
def _get_inputs_and_model(width=1, n_classes=2, use_conv=True):
    """Builds two random input batches and a small `stax` network.

    Args:
      width: width passed to the first layer (`stax.Conv` or `stax.Dense`).
      n_classes: output width of the final `stax.Dense` layer.
      use_conv: if true the first layer is `stax.Conv(width, (3, 3))`;
        otherwise it is `stax.Dense(width)` and the inputs are flattened to
        two dimensions.

    Returns:
      A tuple `(x1, x2, init_fn, apply_fn, kernel_fn, key)`.
    """
    seed_key = stateless_uniform(shape=[2],
                                 seed=[1, 1],
                                 minval=None,
                                 maxval=None,
                                 dtype=tf.int32)
    keys = tf_random_split(seed_key)
    key, split = keys[0], keys[1]

    x1 = np.asarray(normal((8, 4, 3, 2), seed=key))
    x2 = np.asarray(normal((4, 4, 3, 2), seed=split))

    if use_conv:
        first_layer = stax.Conv(width, (3, 3))
    else:
        # Dense network: collapse everything after the leading axis.
        x1 = np.reshape(x1, (x1.shape[0], -1))
        x2 = np.reshape(x2, (x2.shape[0], -1))
        first_layer = stax.Dense(width)

    init_fn, apply_fn, kernel_fn = stax.serial(
        first_layer,
        stax.Relu(),
        stax.Flatten(),
        stax.Dense(n_classes, 2., 0.5))

    return x1, x2, init_fn, apply_fn, kernel_fn, key
def cho_solve(b: np.ndarray, b_axes: Axes) -> np.ndarray:
    """Solves against `C` with `tf.linalg.cholesky_solve` for a reshaped `b`.

    Args:
      b: right-hand side array.
      b_axes: axes of `b` that are moved to the end before `b` is flattened
        into a matrix with `A.shape[1]` rows.

    Returns:
      The solution reshaped to `x_non_channel_shape` followed by the sizes
      of `b` along `b_axes`.
    """
    b_axes = utils.canonicalize_axis(b_axes, b)
    trailing_sizes = tuple(b.shape[axis] for axis in b_axes)
    x_shape = x_non_channel_shape + trailing_sizes

    # Move the selected axes to the end, then flatten to a 2D right-hand side.
    trailing_positions = range(-len(b_axes), 0)
    rhs = np.moveaxis(b, b_axes, trailing_positions)
    rhs = rhs.reshape((A.shape[1], -1))

    solution = np.asarray(tf.linalg.cholesky_solve(C, rhs))
    return solution.reshape(x_shape)