Example #1
0
    def net():
        """Create one reusable ``Parameter`` plus five sized zero parameters.

        The shared ``Parameter`` is called twice: once before the other
        parameters and once after them. Submodule order is decided by the
        first call, so the ``parameter`` calls below must stay in this exact
        sequence. Each concatenated pair has total length 7, and the second
        call of the shared parameter has shape ``(1,)``, so it broadcasts
        over the sum.
        """
        shared = Parameter(lambda key: np.zeros((1, )))
        first = shared()
        vec2 = parameter((2, ), zeros)
        vec3 = parameter((3, ), zeros)
        vec4 = parameter((4, ), zeros)
        vec5 = parameter((5, ), zeros)
        vec6 = parameter((6, ), zeros)

        # must not mess up order (decided by first submodule call):
        again = shared()

        pair_1_6 = np.concatenate([first, vec6])
        pair_2_5 = np.concatenate([vec2, vec5])
        pair_3_4 = np.concatenate([vec3, vec4])
        return pair_1_6 + pair_2_5 + pair_3_4 + again
Example #2
0
    def conv_or_conv_transpose(inputs):
        """Weight-normalised convolution with data-dependent initialisation.

        Creates the raw filter ``V``. It then evaluates ``apply`` once with a
        unit gain and zero bias to obtain ``example_out``. The gain ``g`` and
        bias ``b`` are initialised from the variance and mean of that output
        over axes ``(0, 1, 2)``.

        NOTE(review): the reduction over axes (0, 1, 2) suggests a
        channels-last layout (batch, height, width, channels); confirm
        against callers.

        Args:
            inputs: input array; its last axis is used as the input-channel
                count of ``V``.

        Returns:
            ``apply(inputs, V=V, g=g, b=b)`` evaluated with the created
            parameters.
        """
        V = parameter(filter_shape + (inputs.shape[-1], out_chan), normal(.05), 'V')

        example_out = apply(inputs, V=V, g=np.ones(out_chan), b=np.zeros(out_chan))

        # TODO remove need for `.aval.val` when capturing variables in initializer function:
        # g = init_scale / sqrt(var(example_out) + eps); the eps avoids division by zero.
        g = Parameter(lambda key: init_scale /
                                  np.sqrt(np.var(example_out.aval.val, (0, 1, 2)) + 1e-10), 'g')()
        b = Parameter(lambda key: np.mean(example_out.aval.val, (0, 1, 2)) * g.aval.val, 'b')()

        # Pass gain and bias by keyword, exactly as in the example_out call
        # above. A positional (b, g) call swaps them whenever apply is
        # declared as apply(inputs, V, g, b), which is the order the dense
        # variant uses.
        return apply(inputs, V=V, g=g, b=b)
Example #3
0
    def dense(inputs):
        """Weight-normalised dense layer with data-dependent initialisation.

        The raw weight ``V`` has shape ``(out_chan, inputs.shape[1])``. The
        gain ``'g'`` and bias ``'b'`` are initialised from the variance and
        mean, over axis 0, of the layer output on ``inputs`` computed with a
        unit gain and zero bias.
        """
        weight_v = parameter((out_chan, inputs.shape[1]), randn(stddev=.05), inputs,
                             'V')

        def unit_gain_output():
            # TODO apply = vmap(apply, (0, None, None, None))
            return apply(inputs, weight_v, g=np.ones(out_chan), b=np.zeros(out_chan))

        def init_gain(rng):
            # The 1e-10 keeps the square root away from zero.
            return init_scale / np.sqrt(np.var(unit_gain_output(), 0) + 1e-10)

        gain = Parameter(init_gain, 'g')(inputs)

        def init_bias(rng):
            return np.mean(unit_gain_output(), 0) * gain

        bias = Parameter(init_bias, 'b')(inputs)
        return apply(inputs, weight_v, gain, bias)
Example #4
0
    def wrapper(dummy_inputs):
        """Create zero parameters of sizes 1 through 6 and sum them in pairs.

        The parameters are created in increasing size order. Pairs (1, 6),
        (2, 5) and (3, 4) each concatenate to length 7, so the three results
        can be added element-wise.
        """
        p1, p2, p3, p4, p5, p6 = [
            parameter((size,), zeros, dummy_inputs) for size in range(1, 7)
        ]
        outer = np.concatenate([p1, p6])
        middle = np.concatenate([p2, p5])
        inner = np.concatenate([p3, p4])
        return outer + middle + inner
Example #5
0
 def net(input_dict):
     """Multiply ``input_dict['a']`` by ``input_dict['b']`` and a scalar zero-initialised parameter."""
     product = input_dict['a'] * input_dict['b']
     return product * parameter((), zeros)
Example #6
0
 def net(inputs):
     """Multiply the first two elements of ``inputs`` by a scalar zero-initialised parameter.

     NOTE(review): ``type`` is presumably a variable from an enclosing scope
     (for example ``list`` or ``tuple``). If it resolves to the builtin, the
     assertion passes only when ``inputs`` is itself a class. Confirm this.
     """
     assert isinstance(inputs, type)
     first, second = inputs[0], inputs[1]
     return first * second * parameter((), zeros)
Example #7
0
 def dense(inputs):
     """Apply ``linear_map`` and then add a length-2 zero-initialised ``'bias'`` parameter."""
     # linear_map runs before the bias is created, matching the original
     # left-to-right evaluation order.
     projected = linear_map(inputs)
     bias = parameter((2, ), zeros, 'bias')
     return projected + bias
Example #8
0
 def dense(inputs):
     """Two-output affine layer: ``np.dot(inputs, kernel) + bias``.

     ``'bias'`` is created before ``'kernel'``, and both are
     zero-initialised.
     """
     b = parameter((2, ), zeros, 'bias')
     in_features = inputs.shape[-1]
     w = parameter((in_features, 2), zeros, 'kernel')
     return np.dot(inputs, w) + b
Example #9
0
    def dense(inputs):
        """Return the sum of scalar parameters ``'a'`` and ``'b'``, each created with a ``randn()`` initialiser."""
        first = parameter((), randn(), inputs, 'a')
        second = parameter((), randn(), inputs, 'b')
        total = first + second
        return total
Example #10
0
 def linear_map(inputs):
     """Contract the last axis of ``inputs`` with a zero-initialised ``'kernel'`` of 2 output columns."""
     weights = parameter((inputs.shape[-1], 2), zeros, 'kernel')
     projected = np.dot(inputs, weights)
     return projected
Example #11
0
 def dense(inputs):
     """Affine layer ``np.dot(inputs, kernel) + bias`` with ``out_dim`` outputs.

     ``out_dim``, ``kernel_init`` and ``bias_init`` are free variables
     resolved outside this function.
     """
     in_dim = inputs.shape[-1]
     kernel = parameter((in_dim, out_dim), kernel_init)
     bias = parameter((out_dim,), bias_init)
     output = np.dot(inputs, kernel)
     return output + bias
Example #12
0
 def net(inputs):
     """Return a scalar parameter initialised to 2; ``inputs`` is not used."""
     def init_twos(key, shape):
         return 2 * jnp.ones(shape)

     return parameter((), init_twos)
Example #13
0
 def learnable_scale(params):
     """Scale ``params`` by ``2 * s``, where ``s`` is a scalar parameter created with the ``ones`` initialiser."""
     factor = 2 * parameter((), ones, params)
     return factor * params
Example #14
0
 def net(inputs):
     """Return a scalar parameter initialised to 2, with ``inputs`` passed through to ``parameter``."""
     def init_twos(rng, shape):
         return 2 * np.ones(shape)

     return parameter((), init_twos, inputs)
Example #15
0
    def loss(inputs):
        """Sum of scalar parameters ``'a'`` (``ones`` initialiser) and ``'b'`` (initialised to 2)."""
        def init_twos(rng, shape):
            return 2 * np.ones(shape)

        first = parameter((), ones, inputs, 'a')
        second = parameter((), init_twos, inputs, 'b')
        return first + second
Example #16
0
 def unbatched_dense(input):
     """Affine map ``np.dot(kernel, input) + bias`` for a single (unbatched) input.

     ``kernel`` has shape ``(out_dim, input.shape[-1])``, and both
     parameters use the ``ones`` initialiser. The parameter name ``input``
     shadows the builtin but is kept because callers may pass it by keyword.
     """
     in_dim = input.shape[-1]
     weights = parameter((out_dim, in_dim), ones)
     offset = parameter((out_dim, ), ones)
     return np.dot(weights, input) + offset
Example #17
0
 def net():
     """Return a scalar parameter initialised to 2."""
     def init_twos(key, shape):
         return 2 * np.ones(shape)

     return parameter((), init_twos)
Example #18
0
 def net(inputs):
     """Return ``(inputs, inputs * p)``, where ``p`` is a scalar zero-initialised parameter."""
     scaled = inputs * parameter((), zeros)
     return inputs, scaled
Example #19
0
 def cell(carry, x):
     """Step function ``(carry, x) -> (new_carry, output)`` with a dict carry.

     Args:
         carry: mapping that holds the previous state under key ``'a'``.
         x: the current input element.

     Returns:
         ``({'a': output}, output)``, where
         ``output = scale * np.array([2]) * carry['a'] * x`` and ``scale`` is
         a length-2 zero-initialised parameter.
     """
     scale = parameter((2, ), zeros)
     # Build the step result once and reuse it for both the new carry and the
     # output, instead of evaluating the same expression twice.
     output = scale * np.array([2]) * carry['a'] * x
     return {'a': output}, output
Example #20
0
    def dense():
        """Return ``a + b`` for two scalar parameters created with ``normal()`` initialisers."""
        param_a = parameter((), normal(), 'a')
        param_b = parameter((), normal(), 'b')
        result = param_a + param_b
        return result
Example #21
0
 def net(input_dict):
     """Multiply ``input_dict[0]`` by ``input_dict[1]`` and a scalar zero-initialised parameter.

     Despite its name, ``input_dict`` is indexed with the integer keys 0 and 1.
     ``input_dict[0]`` is also passed to ``parameter``.
     """
     lhs, rhs = input_dict[0], input_dict[1]
     return lhs * rhs * parameter((), zeros, input_dict[0])
Example #22
0
 def cell(carry, x):
     """Step function ``(carry, x) -> (new_carry, output)``, where both results are equal.

     Args:
         carry: the previous state.
         x: the current input element.

     Returns:
         ``(output, output)``, where
         ``output = scale * np.array([2]) * carry * x`` and ``scale`` is a
         length-2 zero-initialised parameter.
     """
     scale = parameter((2, ), zeros)
     # Compute the step once rather than building the identical expression
     # separately for the carry and the output.
     output = scale * np.array([2]) * carry * x
     return output, output
Example #23
0
    def loss(inputs):
        """Sum of scalar parameters ``'a'`` (``ones`` initialiser) and ``'b'`` (initialised to 2); ``inputs`` is unused."""
        def init_twos(key, shape):
            return 2 * jnp.ones(shape)

        first = parameter((), ones, 'a')
        second = parameter((), init_twos, 'b')
        return first + second