Example #1
    def testFunctionInterop(self):
        x = np.asarray(3.0)
        y = np.asarray(2.0)

        add = lambda x, y: x + y
        add_fn = tf.function(add)

        raw_result = add(x, y)
        fn_result = add_fn(x, y)

        self.assertIsInstance(raw_result, np.ndarray)
        self.assertIsInstance(fn_result, np.ndarray)
        self.assertAllClose(raw_result, fn_result)
Example #2
    def testJacobian(self):
        with tf.GradientTape() as g:
            x = np.asarray([1., 2.])
            y = np.asarray([3., 4.])
            g.watch(x)
            g.watch(y)
            z = x * x * y

        jacobian = g.jacobian(z, [x, y])
        answer = [tf.linalg.diag(2 * x * y), tf.linalg.diag(x * x)]

        self.assertIsInstance(jacobian[0], np.ndarray)
        self.assertIsInstance(jacobian[1], np.ndarray)
        self.assertAllClose(jacobian, answer)
Example #3
def normal(key, shape, dtype=tf.float32):
    """Sample standard-normal random values.

  Args:
    key: the RNG key.
    shape: the shape of the result.
    dtype: the dtype of the result.

  Returns:
    Random values in standard-normal distribution.
  """
    key = tf_np.asarray(key, dtype=_RNG_KEY_DTYPE)
    return tf_np.asarray(
        tf.random.stateless_normal(shape, seed=_key2seed(key), dtype=dtype))
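The wrapper above relies on TensorFlow's stateless RNG, where a 2-element integer seed plays the role of the key. A minimal, self-contained sketch of that pattern in plain TensorFlow (independent of the module-internal _key2seed and _RNG_KEY_DTYPE helpers):

import tensorflow as tf

seed = tf.constant([0, 42], dtype=tf.int32)      # stands in for the RNG key
a = tf.random.stateless_normal(shape=(3, 2), seed=seed)
b = tf.random.stateless_normal(shape=(3, 2), seed=seed)
# a and b are identical: stateless sampling is a pure function of (shape, seed, dtype).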
Example #4
  def ntk_fn(x1: np.ndarray,
             x2: Optional[np.ndarray],
             params: PyTree,
             **apply_fn_kwargs) -> np.ndarray:
    """Computes a single sample of the empirical NTK (jacobian outer product).

    Args:
      x1:
        first batch of inputs.
      x2:
        second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a
        matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`.
      params:
        A `PyTree` of parameters about which we would like to compute the
        neural tangent kernel.
      **apply_fn_kwargs:
        keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be split
        into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `_split_kwargs`
        function which will be passed to `apply_fn`. In particular, the rng key
        in `apply_fn_kwargs`, will be split into two different (if `x1!=x2`) or
        same (if `x1==x2`) rng keys. See the `_read_key` function for more
        details.

    Returns:
      A single sample of the empirical NTK. The shape of the kernel is "almost"
      `zip(f(x1).shape, f(x2).shape)` except for:
      1) `trace_axes` are absent as they are contracted over.
      2) `diagonal_axes` are present only once.
      All other axes are present twice.
    """

    apply_fn_kwargs1, apply_fn_kwargs2 = _split_kwargs(apply_fn_kwargs, x1, x2)
    f1 = _get_f_params(f, x1, **apply_fn_kwargs1)
    with tf.GradientTape() as tape:
      tape.watch(params)
      y = f1(params)
    j1 = np.asarray(tape.jacobian(y, params))
    if x2 is None:
      j2 = j1
    else:
      f2 = _get_f_params(f, x2, **apply_fn_kwargs2)
      with tf.GradientTape() as tape:
        tape.watch(params)
        y = f2(params)
      j2 = np.asarray(tape.jacobian(y, params))

    fx1 = eval_on_shapes(f1)(params)
    ntk = sum_and_contract(j1, j2, fx1.ndim)
    return ntk / utils.size_at(fx1, trace_axes)
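As a hedged illustration of the "jacobian outer product" the docstring refers to, here is a minimal plain-TensorFlow sketch with a toy linear model (the names, shapes, and model are illustrative only, not part of the library):

import tensorflow as tf

theta = tf.constant([0.5, -1.0])                 # toy parameter vector
x1 = tf.constant([[1.0, 2.0]])                   # one input
x2 = tf.constant([[3.0, 4.0]])                   # another input
f = lambda p, x: tf.linalg.matvec(x, p)          # toy linear model f(x) = x @ p

with tf.GradientTape() as tape:
  tape.watch(theta)
  y1 = f(theta, x1)
j1 = tape.jacobian(y1, theta)                    # shape [1, 2]

with tf.GradientTape() as tape:
  tape.watch(theta)
  y2 = f(theta, x2)
j2 = tape.jacobian(y2, theta)                    # shape [1, 2]

# Contract over the parameter axis: NTK[i, j] = sum_k j1[i, k] * j2[j, k].
ntk = tf.tensordot(j1, j2, axes=[[1], [1]])      # [[11.]] for this linear model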
Example #5
    def testBatchJacobian(self):
        with tf.GradientTape() as g:
            x = np.asarray([[1., 2.], [3., 4.]])
            y = np.asarray([[3., 4.], [5., 6.]])
            g.watch(x)
            g.watch(y)
            z = x * x * y

        batch_jacobian = g.batch_jacobian(z, x)
        answer = tf.stack(
            [tf.linalg.diag(2 * x[0] * y[0]),
             tf.linalg.diag(2 * x[1] * y[1])])

        self.assertIsInstance(batch_jacobian, np.ndarray)
        self.assertAllClose(batch_jacobian, answer)
Example #6
    def testPyFuncInterop(self):
        def py_func_fn(a, b):
            return a + b

        @tf.function
        def fn(a, b):
            result = tf.py_function(py_func_fn, [a, b], a.dtype)
            return np.asarray(result)

        a = np.asarray(1.)
        b = np.asarray(2.)

        result = fn(a, b)
        self.assertIsInstance(result, np.ndarray)
        self.assertAllClose(result, 3.)
Example #7
    def testAxes(self, diagonal_axes, trace_axes):
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=3)
        key = splits[0]
        self_split = splits[1]
        other_split = splits[2]
        data_self = np.asarray(normal((4, 5, 6, 3), seed=self_split))
        data_other = np.asarray(normal((2, 5, 6, 3), seed=other_split))

        _diagonal_axes = utils.canonicalize_axis(diagonal_axes, data_self)
        _trace_axes = utils.canonicalize_axis(trace_axes, data_self)

        if any(d == c for d in _diagonal_axes for c in _trace_axes):
            raise absltest.SkipTest(
                'diagonal axes must be different from channel axes.')

        implicit, direct, nngp = KERNELS['empirical_logits_3'](
            key, (5, 6, 3),
            CONV,
            diagonal_axes=diagonal_axes,
            trace_axes=trace_axes)

        n_marg = len(_diagonal_axes)
        n_chan = len(_trace_axes)

        g = implicit(data_self, None)
        g_direct = direct(data_self, None)
        g_nngp = nngp(data_self, None)

        self.assertAllClose(g, g_direct)
        self.assertEqual(g_nngp.shape, g.shape)
        self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg, g_nngp.ndim)

        if 0 not in _trace_axes and 0 not in _diagonal_axes:
            g = implicit(data_other, data_self)
            g_direct = direct(data_other, data_self)
            g_nngp = nngp(data_other, data_self)

            self.assertAllClose(g, g_direct)
            self.assertEqual(g_nngp.shape, g.shape)
            self.assertEqual(2 * (data_self.ndim - n_chan) - n_marg,
                             g_nngp.ndim)
Example #8
    def testGradientTapeInterop(self):
        with tf.GradientTape() as t:
            x = np.asarray(3.0)
            y = np.asarray(2.0)

            t.watch([x, y])

            xx = 2 * x
            yy = 3 * y

        dx, dy = t.gradient([xx, yy], [x, y])

        self.assertIsInstance(dx, np.ndarray)
        self.assertIsInstance(dy, np.ndarray)
        self.assertAllClose(dx, 2.0)
        self.assertAllClose(dy, 3.0)
Example #9
    def test_tf_conv_general_dilated(self, lhs_shape, rhs_shape, strides,
                                     padding, lhs_dilation, rhs_dilation,
                                     feature_group_count, batch_group_count,
                                     dimension_numbers, perms):
        tf.print("dimension_numbers: {}".format(dimension_numbers),
                 output_stream=sys.stdout)
        lhs_perm, rhs_perm = perms  # permute to compatible shapes

        lhs_tf = tfnp.transpose(tfnp.ones(lhs_shape), lhs_perm)
        rhs_tf = tfnp.transpose(tfnp.ones(rhs_shape), rhs_perm)

        lhs_jax = jnp.transpose(jnp.ones(lhs_shape), lhs_perm)
        rhs_jax = jnp.transpose(jnp.ones(rhs_shape), rhs_perm)

        jax_conv = jax.lax.conv_general_dilated(
            lhs_jax, rhs_jax, strides, padding, lhs_dilation, rhs_dilation,
            dimension_numbers, feature_group_count, batch_group_count)

        tf_conv = lax.conv_general_dilated(lhs_tf, rhs_tf, strides, padding,
                                           jax_conv.shape, lhs_dilation,
                                           rhs_dilation, dimension_numbers,
                                           feature_group_count,
                                           batch_group_count)

        self.assertAllEqual(tf_conv, tfnp.asarray(jax_conv))
Example #10
    def testTFNPArrayTFOpInterop(self):
        arr = np.asarray(10.)

        # TODO(nareshmodi): Test more ops.
        sq = tf.square(arr)
        self.assertIsInstance(sq, tf.Tensor)
        self.assertEqual(100., sq.numpy())
Example #11
    def testDistStratInterop(self):
        strategy = tf.distribute.MirroredStrategy(
            devices=['CPU:0', 'CPU:1', 'CPU:2'])

        multiplier = np.asarray(5.)

        @tf.function
        def run():
            ctx = tf.distribute.get_replica_context()
            val = np.asarray(ctx.replica_id_in_sync_group)
            return val * multiplier

        distributed_values = strategy.run(run)
        reduced = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                  distributed_values,
                                  axis=None)

        values = strategy.experimental_local_results(distributed_values)

        # Note that this should match the number of virtual CPUs.
        self.assertLen(values, 3)
        self.assertIsInstance(values[0], np.ndarray)
        self.assertIsInstance(values[1], np.ndarray)
        self.assertIsInstance(values[2], np.ndarray)
        self.assertAllClose(values[0], 0)
        self.assertAllClose(values[1], 5)
        self.assertAllClose(values[2], 10)

        # "strategy.reduce" doesn't rewrap in ndarray.
        # self.assertIsInstance(reduced, np.ndarray)
        self.assertAllClose(reduced, 15)
Example #12
    def testTFNPArrayNPOpInterop(self):
        arr = np.asarray([10.])

        # TODO(nareshmodi): Test more ops.
        sq = onp.square(arr)
        self.assertIsInstance(sq, onp.ndarray)
        self.assertEqual(100., sq[0])
Example #13
    def testDatasetInterop(self):
        values = [1, 2, 3, 4, 5, 6]
        values_as_array = np.asarray(values)

        # Tensor dataset
        dataset = tf.data.Dataset.from_tensors(values_as_array)

        for value, value_from_dataset in zip([values_as_array], dataset):
            self.assertIsInstance(value_from_dataset, np.ndarray)
            self.assertAllEqual(value_from_dataset, value)

        # Tensor slice dataset
        dataset = tf.data.Dataset.from_tensor_slices(values_as_array)

        for value, value_from_dataset in zip(values, dataset):
            self.assertIsInstance(value_from_dataset, np.ndarray)
            self.assertAllEqual(value_from_dataset, value)

        # # TODO(nareshmodi): as_numpy_iterator() doesn't work.
        # items = list(dataset.as_numpy_iterator())

        # Map over a dataset.
        dataset = dataset.map(lambda x: np.add(x, 1))

        for value, value_from_dataset in zip(values, dataset):
            self.assertIsInstance(value_from_dataset, np.ndarray)
            self.assertAllEqual(value_from_dataset, value + 1)

        # Batch a dataset.
        dataset = tf.data.Dataset.from_tensor_slices(values_as_array).batch(2)

        for value, value_from_dataset in zip([[1, 2], [3, 4], [5, 6]],
                                             dataset):
            self.assertIsInstance(value_from_dataset, np.ndarray)
            self.assertAllEqual(value_from_dataset, value)
Example #14
    def testIndex(self):
        @tf.function
        def f(x):
            return [0, 1][x]

        with self.assertRaises(TypeError):
            f(np.asarray([1]))
Example #15
def logsumexp(x, axis=None, keepdims=None):
    """Computes log(sum(exp(elements across dimensions of a tensor))).

  Reduces `x` along the dimensions given in `axis`.
  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keepdims` is true, the reduced dimensions
  are retained with length 1.
  If `axis` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.
  This function is more numerically stable than log(sum(exp(input))). It avoids
  overflows caused by taking the exp of large inputs and underflows caused by
  taking the log of small inputs.

  Args:
    x: The tensor to reduce. Should have numeric type.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(x), rank(x))`.
    keepdims: If true, retains reduced dimensions with length 1.

  Returns:
    The reduced tensor.
  """
    return tf_np.asarray(
        tf.math.reduce_logsumexp(input_tensor=x.data,
                                 axis=axis,
                                 keepdims=keepdims))
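The numerical-stability claim in the docstring comes down to the max-shift identity; a minimal, self-contained sketch in plain TensorFlow (no tf_np wrapper):

import tensorflow as tf

x = tf.constant([1000., 1001., 1002.])
naive = tf.math.log(tf.reduce_sum(tf.exp(x)))              # overflows to inf in float32
m = tf.reduce_max(x)
stable = m + tf.math.log(tf.reduce_sum(tf.exp(x - m)))     # ~1002.41
builtin = tf.math.reduce_logsumexp(x)                      # matches `stable`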
Example #16
def sort_key_val(keys, values, dimension=-1):
    """Sorts keys along a dimension and applies same permutation to values.

  Args:
    keys: an array. The dtype must be comparable numbers (integers and reals).
    values: an array, with the same shape of `keys`.
    dimension: an `int`. The dimension along which to sort.

  Returns:
    Permuted keys and values.
  """
    keys = tf_np.asarray(keys)
    values = tf_np.asarray(values)
    rank = keys.data.shape.ndims
    if rank is None:
        rank = values.data.shape.ndims
    if rank is None:
        # We need to know the rank because tf.gather requires batch_dims to be `int`
        raise ValueError(
            "The rank of either keys or values must be known, but "
            "both are unknown (i.e. their shapes are both None).")
    if dimension in (-1, rank - 1):

        def maybe_swapaxes(a):
            return a
    else:

        def maybe_swapaxes(a):
            return tf_np.swapaxes(a, dimension, -1)

    # We need to swap axes because tf.gather (and tf.gather_nd) supports
    # batch_dims on the left but not on the right.
    # TODO(wangpeng): Investigate whether we should do swapaxes or moveaxis.
    keys = maybe_swapaxes(keys)
    values = maybe_swapaxes(values)
    idxs = tf_np.argsort(keys)
    idxs = idxs.data

    # Using tf.gather rather than np.take because the former supports batch_dims
    def gather(a):
        return tf_np.asarray(tf.gather(a.data, idxs, batch_dims=rank - 1))

    keys = gather(keys)
    values = gather(values)
    keys = maybe_swapaxes(keys)
    values = maybe_swapaxes(values)
    return keys, values
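A minimal, self-contained sketch of the argsort-then-gather pattern used above, in plain TensorFlow along the last axis (so batch_dims = rank - 1 and no axis swap is needed):

import tensorflow as tf

keys = tf.constant([[3., 1., 2.], [0., 5., 4.]])
values = tf.constant([[30., 10., 20.], [0., 50., 40.]])
idxs = tf.argsort(keys, axis=-1)
sorted_keys = tf.gather(keys, idxs, batch_dims=1)      # [[1., 2., 3.], [0., 4., 5.]]
sorted_values = tf.gather(values, idxs, batch_dims=1)  # [[10., 20., 30.], [0., 40., 50.]]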
Example #17
    def testMapFn(self):
        x = np.asarray([1., 2.])
        mapped_x = tf.map_fn(lambda x: (x[0] + 1, x[1] + 1), (x, x))

        self.assertIsInstance(mapped_x[0], np.ndarray)
        self.assertIsInstance(mapped_x[1], np.ndarray)
        self.assertAllClose(mapped_x[0], [2., 3.])
        self.assertAllClose(mapped_x[1], [2., 3.])
Example #18
    def testGradientDescentMseEnsembleTrain(self):
        key = stateless_uniform(shape=[2],
                                seed=[1, 1],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        x = np.asarray(normal((8, 4, 6, 3), seed=key))
        _, _, kernel_fn = stax.serial(stax.Conv(1, (2, 2)), stax.Relu(),
                                      stax.Conv(1, (2, 1)))
        y = np.asarray(normal((8, 2, 5, 1), seed=key))
        predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x, y)

        for t in [None, np.array([0., 1., 10.])]:
            with self.subTest(t=t):
                y_none = predictor(t, None, None, compute_cov=True)
                y_x = predictor(t, x, None, compute_cov=True)
                self._assertAllClose(y_none, y_x, 0.04)
Example #19
  def testGradientTapeInterop(self):
    with tf.GradientTape() as t:
      x = np.asarray(3.0)
      y = np.asarray(2.0)

      t.watch([x, y])

      xx = 2 * x
      yy = 3 * y

    dx, dy = t.gradient([xx, yy], [x, y])

    # # TODO(nareshmodi): Figure out a way to rewrap ndarray as tensors.
    # self.assertIsInstance(dx, np.ndarray)
    # self.assertIsInstance(dy, np.ndarray)
    self.assertAllClose(dx, 2.0)
    self.assertAllClose(dy, 3.0)
Example #20
    def _f(*args):
        tf_args = tf.nest.map_structure(lambda x: tf_np.asarray(x).data, args)

        def tf_f(x):
            return f(*x)

        outputs = tf.vectorized_map(tf_f, tf_args)
        return tf.nest.map_structure(tf_np.asarray, outputs)
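A minimal, self-contained sketch of tf.vectorized_map, the primitive the wrapper above builds on; the mapped function is applied across axis 0 of each argument:

import tensorflow as tf

xs = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
row_sums = tf.vectorized_map(lambda row: tf.reduce_sum(row), xs)  # [3., 7., 11.]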
Example #21
    def testIter(self):
        @tf.function
        def f(x):
            y, z = x
            return y, z

        with self.assertRaises(TypeError):
            f(np.asarray([3, 4]))
Example #22
def conv_general_dilated(lhs, rhs, window_strides, padding, output_shape, lhs_dilation=None,
                         rhs_dilation=None, dimension_numbers=None,
                         feature_group_count=1, batch_group_count=1, precision=None):
  """ A general conv API that integrates normal conv, deconvolution,
  dilated convolution, etc."""
  dim = None
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  if lhs_spec != out_spec:
    raise TypeError('Current implementation requires the `data_format` of the '
                    'inputs and outputs to be the same.')
  if len(lhs_spec) >= 6:
    raise TypeError('Current implementation does not support 4 or higher '
                    'dimensional convolution, but got: {}'.format(
                        len(lhs_spec) - 2))
  dim = len(lhs_spec) - 2
  if lhs_dilation and rhs_dilation:
    if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim:
      lhs_dilation, rhs_dilation = None, None
    else:
      raise TypeError('Current implementation does not support deconvolution '
                      'and dilation being performed at the same time, but got '
                      'lhs_dilation: {}, rhs_dilation: {}'.format(
                          lhs_dilation, rhs_dilation))
  if padding not in ['SAME', 'VALID']:
    raise TypeError('Current implementation requires the padding parameter '
                    'to be either `VALID` or `SAME`, but got: {}'.format(padding))
  # Convert params from int/Sequence[int] to list of ints.
  strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
    window_strides, lhs_dilation, rhs_dilation
  )
  # Preprocess the shapes
  dim_maps = {}
  if isinstance(lhs_spec, str):
    dim_maps['I'] = list(rhs_spec).index('I')
    dim_maps['O'] = list(rhs_spec).index('O')
    dim_maps['N'] = list(lhs_spec).index('N')
    dim_maps['C'] = list(lhs_spec).index('C')
  else:
    dim_maps['I'] = rhs_spec[1]
    dim_maps['O'] = rhs_spec[0]
    dim_maps['N'] = lhs_spec[0]
    dim_maps['C'] = lhs_spec[1]

  lhs = np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1))
  # Adjust the filters, put the dimension 'I' and 'O' at last.
  rhs = np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim))
  spatial_dim_maps = {1: 'W', 2: 'HW', 3: 'DHW'}
  data_format = 'N' + spatial_dim_maps[dim] + 'C'
  tf_nn_APIs = {1: [nn.conv1d, nn.conv1d_transpose],
                2: [nn.conv2d, nn.conv2d_transpose],
                3: [nn.conv3d, nn.conv3d_transpose]}

  output = None
  if rhs_dilation or (lhs_dilation is None and rhs_dilation is None):
    output = tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format, rhs_dilation)
  else:
    output = tf_nn_APIs[dim][1](lhs, rhs, tf.constant(output_shape), strides, padding, data_format, lhs_dilation)
  output = np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C']))
  return np.asarray(output)
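A minimal, self-contained sketch of the two TensorFlow primitives this wrapper dispatches to, here for the 2-D case in NHWC/HWIO layout; the shapes are illustrative only:

import tensorflow as tf

lhs = tf.ones((1, 8, 8, 3))       # N, H, W, C
rhs = tf.ones((3, 3, 3, 4))       # filter H, W, in-channels, out-channels

# rhs (atrous) dilation: plain convolution with `dilations`.
conv = tf.nn.conv2d(lhs, rhs, strides=1, padding='SAME',
                    data_format='NHWC', dilations=(2, 2))

# lhs dilation (deconvolution): transposed convolution with an explicit output shape.
rhs_t = tf.ones((3, 3, 4, 3))     # transpose expects H, W, out-channels, in-channels
deconv = tf.nn.conv2d_transpose(lhs, rhs_t, output_shape=(1, 16, 16, 4),
                                strides=2, padding='SAME', data_format='NHWC')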
Example #23
    def testSerial(self, train_shape, test_shape, network, name, kernel_fn,
                   batch_size):
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(key, 3)
        key = keys[0]
        self_split = keys[1]
        other_split = keys[2]
        data_self = np.asarray(normal(train_shape, seed=self_split))
        data_other = np.asarray(normal(test_shape, seed=other_split))
        kernel_fn = kernel_fn(key, train_shape[1:], network)
        kernel_batched = batch._serial(kernel_fn, batch_size=batch_size)

        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other)
Example #24
    def testLen(self):
        @tf.function
        def f(x):
            # Note that shape of input to len is data dependent.
            return len(np.where(x)[0])

        t = np.asarray([True, False, True])
        with self.assertRaises(TypeError):
            f(t)
Example #25
    def testLinearization(self, shape):
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=4)
        key = splits[0]
        s1 = splits[1]
        s2 = splits[2]
        s3 = splits[3]
        w1 = np.asarray(normal(shape, seed=s1))
        w1 = 0.5 * (w1 + w1.T)
        w2 = np.asarray(normal(shape, seed=s2))
        b = np.asarray(normal((shape[-1], ), seed=s3))
        params = (w1, w2, b)

        splits = tf_random_split(seed=tf.convert_to_tensor(key,
                                                           dtype=tf.int32),
                                 num=2)
        key = splits[0]
        split = splits[1]
        x0 = np.asarray(normal((shape[-1], ), seed=split))

        f_lin = empirical.linearize(EmpiricalTest.f, x0)

        for _ in range(TAYLOR_RANDOM_SAMPLES):
            for do_alter in [True, False]:
                for do_shift_x in [True, False]:
                    splits = tf_random_split(seed=tf.convert_to_tensor(
                        key, dtype=tf.int32),
                                             num=2)
                    key = splits[0]
                    split = splits[1]
                    x = np.asarray(normal((shape[-1], ), seed=split))
                    self.assertAllClose(
                        EmpiricalTest.f_lin_exact(x0,
                                                  x,
                                                  params,
                                                  do_alter,
                                                  do_shift_x=do_shift_x),
                        f_lin(x, params, do_alter, do_shift_x=do_shift_x))
Example #26
def train(batch_size, learning_rate, num_training_iters, validation_steps):
    """ training loop """
    # Loading the MNIST Dataset
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    x_train = np.asarray(tf.reshape(x_train, (-1, 784))).astype(np.float32)
    y_train = np.asarray(tf.one_hot(y_train, 10)).astype(np.float32)
    x_test = np.asarray(tf.reshape(x_test, (-1, 784))).astype(np.float32)
    y_test = np.asarray(tf.one_hot(y_test, 10)).astype(np.float32)

    mnist_train = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).batch(batch_size)
    mnist_test = tf.data.Dataset.from_tensor_slices(
        (x_test, y_test)).batch(batch_size)

    print('Initialized MNIST with {} training and {} test examples.'.format(
        x_train.shape[0], x_test.shape[0]))

    model = Model([128, 32], learning_rate=learning_rate)

    # The training loop
    loss = 0
    for i in range(num_training_iters):
        for x, y in mnist_train:
            loss += model.train(x, y)

        # Calculate and print the train and test accuracy
        if not (i + 1) % validation_steps:
            correct_train_predictions = 0
            for train_x, train_y in mnist_train:
                correct_train_predictions += model.evaluate(train_x, train_y)

            correct_test_predictions = 0
            for test_x, test_y in mnist_test:
                correct_test_predictions += model.evaluate(test_x, test_y)

            print('[{}] Loss: {}, train acc: {}, test acc: {}'.format(
                i + 1, round(float(loss.data / validation_steps), 4),
                round(correct_train_predictions / x_train.shape[0], 4),
                round(correct_test_predictions / x_test.shape[0], 4)))

            loss = 0
Example #27
    def testParallel(self, train_shape, test_shape, network, name, kernel_fn):
        test_utils.stub_out_pmap(batch, 2)
        key = stateless_uniform(shape=[2],
                                seed=[0, 0],
                                minval=None,
                                maxval=None,
                                dtype=tf.int32)
        keys = tf_random_split(key, 3)
        key = keys[0]
        self_split = keys[1]
        other_split = keys[2]
        data_self = np.asarray(normal(train_shape, seed=self_split))
        data_other = np.asarray(normal(test_shape, seed=other_split))

        kernel_fn = kernel_fn(key, train_shape[1:], network, use_dropout=False)
        kernel_batched = batch._parallel(kernel_fn)

        _test_kernel_against_batched(self, kernel_fn, kernel_batched,
                                     data_self, data_other, True)
Example #28
    def testTensorTFNPArrayInterop(self):
        arr = np.asarray(0.)
        t = tf.constant(10.)

        arr_plus_t = arr + t
        t_plus_arr = t + arr

        self.assertIsInstance(arr_plus_t, tf.Tensor)
        self.assertIsInstance(t_plus_arr, tf.Tensor)
        self.assertEqual(10., arr_plus_t.numpy())
        self.assertEqual(10., t_plus_arr.numpy())
Example #29
def _get_inputs_and_model(width=1, n_classes=2, use_conv=True):
    key = stateless_uniform(shape=[2],
                            seed=[1, 1],
                            minval=None,
                            maxval=None,
                            dtype=tf.int32)
    keys = tf_random_split(key)
    key = keys[0]
    split = keys[1]
    x1 = np.asarray(normal((8, 4, 3, 2), seed=key))
    x2 = np.asarray(normal((4, 4, 3, 2), seed=split))

    if not use_conv:
        x1 = np.reshape(x1, (x1.shape[0], -1))
        x2 = np.reshape(x2, (x2.shape[0], -1))

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Conv(width, (3, 3)) if use_conv else stax.Dense(width),
        stax.Relu(), stax.Flatten(), stax.Dense(n_classes, 2., 0.5))
    return x1, x2, init_fn, apply_fn, kernel_fn, key
Example #30
    def cho_solve(b: np.ndarray, b_axes: Axes) -> np.ndarray:
        b_axes = utils.canonicalize_axis(b_axes, b)
        last_b_axes = range(-len(b_axes), 0)
        x_shape = x_non_channel_shape + tuple(b.shape[a] for a in b_axes)

        b = np.moveaxis(b, b_axes, last_b_axes)
        b = b.reshape((A.shape[1], -1))

        x = np.asarray(tf.linalg.cholesky_solve(C, b))
        x = x.reshape(x_shape)
        return x
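A minimal, self-contained sketch of the tf.linalg.cholesky_solve call that cho_solve relies on: solve A x = b from the Cholesky factor C of a symmetric positive-definite A.

import tensorflow as tf

A = tf.constant([[4., 2.], [2., 3.]])   # symmetric positive-definite
b = tf.constant([[2.], [1.]])
C = tf.linalg.cholesky(A)               # lower-triangular factor of A
x = tf.linalg.cholesky_solve(C, b)      # satisfies A @ x ≈ b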