def test_Dense_equivalent():
    """Pass a hand-written Dense-like class to the shared ``test_Dense_shape`` checks.

    NOTE(review): another ``test_Dense_equivalent`` is defined later in this
    file and rebinds this name at import time, so pytest only collects the
    later definition -- confirm which version is intended.
    """

    class DenseEquivalent:
        """Affine layer exposing ``apply`` / ``init_parameters`` / ``shaped``."""

        def __init__(self, out_dim, kernel_init=glorot_normal(), bias_init=normal()):
            # Keep configuration only; the initializers run in init_parameters.
            self.out_dim = out_dim
            self.kernel_init = kernel_init
            self.bias_init = bias_init

        def apply(self, params, inputs):
            """Return ``jnp.dot(inputs, kernel) + bias`` for ``params == (kernel, bias)``."""
            weights, offset = params
            return jnp.dot(inputs, weights) + offset

        def init_parameters(self, example_inputs, key):
            """Draw a ``(kernel, bias)`` pair sized from the last axis of ``example_inputs``."""
            # One independent subkey per parameter.
            k_kernel, k_bias = random.split(key, 2)
            kernel_value = self.kernel_init(
                k_kernel, (example_inputs.shape[-1], self.out_dim)
            )
            bias_value = self.bias_init(k_bias, (self.out_dim, ))
            DenseParams = namedtuple('dense', ['kernel', 'bias'])
            return DenseParams(kernel=kernel_value, bias=bias_value)

        def shaped(self, example_inputs):
            """Bind this layer to ``example_inputs`` via ``ShapedParametrized``."""
            return ShapedParametrized(self, example_inputs)

    test_Dense_shape(DenseEquivalent)
def test_parameter_Dense_equivalent():
    """Build a Dense-like layer from ``@parametrized`` + ``Parameter`` and run ``test_Dense_shape`` on it.

    NOTE(review): another ``test_parameter_Dense_equivalent`` is defined later
    in this file and rebinds this name at import time, so pytest only collects
    the later definition -- confirm which version is intended.
    """
    # Factory returning a parametrized affine function with output width `out_dim`.
    # The default initializers are created once, when this `def` executes.
    def DenseEquivalent(out_dim, kernel_init=glorot_normal(), bias_init=normal()):
        @parametrized
        def dense(inputs):
            # Each lambda maps a key to an initial value. The kernel's input width
            # is read from `inputs`, so it follows the data it is applied to.
            # NOTE(review): the trailing `()` presumably yields the parameter's
            # value inside the parametrized trace -- confirm against `Parameter`.
            kernel = Parameter(lambda key: kernel_init(key, (inputs.shape[-1], out_dim)))()
            bias = Parameter(lambda key: bias_init(key, (out_dim,)))()
            return np.dot(inputs, kernel) + bias
        return dense
    test_Dense_shape(DenseEquivalent)
def test_parameter_Dense_equivalent():
    """Build a Dense layer from ``@parametrized`` + ``Parameter`` and run ``test_Dense_shape`` on it."""
    # Factory returning a parametrized affine function with output width `out_dim`.
    # The default initializers are created once, when this `def` executes.
    def Dense(out_dim, kernel_init=glorot(), bias_init=randn()):
        @parametrized
        def dense(inputs):
            # Each lambda maps an rng to an initial value. The kernel's input width
            # is read from `inputs`, so it follows the data it is applied to.
            # NOTE(review): `inputs` is passed to each Parameter call; presumably
            # the framework uses it for tracing/initialization -- confirm against `Parameter`.
            kernel = Parameter(lambda rng: kernel_init(rng, (inputs.shape[-1], out_dim)))(inputs)
            bias = Parameter(lambda rng: bias_init(rng, (out_dim, )))(inputs)
            return np.dot(inputs, kernel) + bias
        return dense
    test_Dense_shape(Dense)
def test_Dense_equivalent():
    """Pass a hand-written ``Dense`` class to the shared ``test_Dense_shape`` checks."""

    class Dense:
        """Affine layer exposing ``apply`` / ``init_parameters`` / ``shaped``."""

        # Parameter container; field order matches the unpacking in ``apply``.
        Params = namedtuple('dense', ['kernel', 'bias'])

        def __init__(self, out_dim, kernel_init=glorot(), bias_init=randn()):
            # Keep configuration only; the initializers run in init_parameters.
            self.out_dim = out_dim
            self.kernel_init = kernel_init
            self.bias_init = bias_init

        def apply(self, params, inputs):
            """Return ``np.dot(inputs, kernel) + bias`` for ``params == (kernel, bias)``."""
            weights, offset = params
            return np.dot(inputs, weights) + offset

        def init_parameters(self, rng, example_inputs):
            """Draw ``Dense.Params`` sized from the last axis of ``example_inputs``."""
            # One independent subkey per parameter.
            key_w, key_b = random.split(rng, 2)
            kernel_shape = (example_inputs.shape[-1], self.out_dim)
            kernel_value = self.kernel_init(key_w, kernel_shape)
            bias_value = self.bias_init(key_b, (self.out_dim, ))
            return Dense.Params(kernel=kernel_value, bias=bias_value)

        def shaped(self, example_inputs):
            """Bind this layer to ``example_inputs`` via ``ShapedParametrized``."""
            return ShapedParametrized(self, example_inputs)

    test_Dense_shape(Dense)