コード例 #1
0
    def net():
        p = Parameter(lambda key: np.zeros((1, )))
        a = p()
        b = parameter((2, ), zeros)
        c = parameter((3, ), zeros)
        d = parameter((4, ), zeros)
        e = parameter((5, ), zeros)
        f = parameter((6, ), zeros)

        # must not mess up order (decided by first submodule call):
        k = p()

        return np.concatenate([a, f]) + np.concatenate(
            [b, e]) + np.concatenate([c, d]) + k
コード例 #2
0
    def conv_or_conv_transpose(inputs):
        V = parameter(filter_shape + (inputs.shape[-1], out_chan), normal(.05), 'V')

        example_out = apply(inputs, V=V, g=np.ones(out_chan), b=np.zeros(out_chan))

        # TODO remove need for `.aval.val` when capturing variables in initializer function:
        g = Parameter(lambda key: init_scale /
                                  np.sqrt(np.var(example_out.aval.val, (0, 1, 2)) + 1e-10), 'g')()
        b = Parameter(lambda key: np.mean(example_out.aval.val, (0, 1, 2)) * g.aval.val, 'b')()

        return apply(inputs, V, b, g)
コード例 #3
0
ファイル: pixelcnn.py プロジェクト: yueyedeai/jaxnet
    def dense(inputs):
        V = parameter((out_chan, inputs.shape[1]), randn(stddev=.05), inputs,
                      'V')

        # TODO apply = vmap(apply, (0, None, None, None))
        example_output = lambda: apply(
            inputs, V, g=np.ones(out_chan), b=np.zeros(out_chan))

        g = Parameter(
            lambda rng: init_scale / np.sqrt(
                np.var(example_output(), 0) + 1e-10), 'g')(inputs)
        b = Parameter(lambda rng: np.mean(example_output(), 0) * g,
                      'b')(inputs)
        return apply(inputs, V, g, b)
コード例 #4
0
    def wrapper(dummy_inputs):
        a = parameter((1,), zeros, dummy_inputs)
        b = parameter((2,), zeros, dummy_inputs)
        c = parameter((3,), zeros, dummy_inputs)
        d = parameter((4,), zeros, dummy_inputs)
        e = parameter((5,), zeros, dummy_inputs)
        f = parameter((6,), zeros, dummy_inputs)

        return np.concatenate([a, f]) + np.concatenate([b, e]) + np.concatenate([c, d])
コード例 #5
0
 def net(input_dict):
     return input_dict['a'] * input_dict['b'] * parameter((), zeros)
コード例 #6
0
 def net(inputs):
     assert isinstance(inputs, type)
     return inputs[0] * inputs[1] * parameter((), zeros)
コード例 #7
0
 def dense(inputs):
     return linear_map(inputs) + parameter((2, ), zeros, 'bias')
コード例 #8
0
 def dense(inputs):
     bias = parameter((2, ), zeros, 'bias')
     kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
     return np.dot(inputs, kernel) + bias
コード例 #9
0
    def dense(inputs):
        a = parameter((), randn(), inputs, 'a')
        b = parameter((), randn(), inputs, 'b')

        return a + b
コード例 #10
0
 def linear_map(inputs):
     kernel = parameter((inputs.shape[-1], 2), zeros, 'kernel')
     return np.dot(inputs, kernel)
コード例 #11
0
ファイル: test_examples.py プロジェクト: tom-bird/jaxnet
 def dense(inputs):
     kernel = parameter((inputs.shape[-1], out_dim), kernel_init)
     bias = parameter((out_dim,), bias_init)
     return np.dot(inputs, kernel) + bias
コード例 #12
0
ファイル: test_modules.py プロジェクト: juliuskunze/jaxnet
 def net(inputs):
     return parameter((), lambda key, shape: 2 * jnp.ones(shape))
コード例 #13
0
ファイル: test_modules.py プロジェクト: QUELUCIFER/jaxnet
 def learnable_scale(params):
     return 2 * parameter((), ones, params) * params
コード例 #14
0
ファイル: test_modules.py プロジェクト: QUELUCIFER/jaxnet
 def net(inputs):
     return parameter((), lambda rng, shape: 2 * np.ones(shape), inputs)
コード例 #15
0
ファイル: test_modules.py プロジェクト: QUELUCIFER/jaxnet
    def loss(inputs):
        a = parameter((), ones, inputs, 'a')
        b = parameter((), lambda rng, shape: 2 * np.ones(shape), inputs, 'b')

        return a + b
コード例 #16
0
ファイル: test_modules.py プロジェクト: j-towns/jaxnet
 def unbatched_dense(input):
     kernel = parameter((out_dim, input.shape[-1]), ones)
     bias = parameter((out_dim, ), ones)
     return np.dot(kernel, input) + bias
コード例 #17
0
ファイル: test_modules.py プロジェクト: j-towns/jaxnet
 def net():
     return parameter((), lambda key, shape: 2 * np.ones(shape))
コード例 #18
0
 def net(inputs):
     return inputs, inputs * parameter((), zeros)
コード例 #19
0
 def cell(carry, x):
     scale = parameter((2, ), zeros)
     return {
         'a': scale * np.array([2]) * carry['a'] * x
     }, scale * np.array([2]) * carry['a'] * x
コード例 #20
0
    def dense():
        a = parameter((), normal(), 'a')
        b = parameter((), normal(), 'b')

        return a + b
コード例 #21
0
 def net(input_dict):
     return input_dict[0] * input_dict[1] * parameter((), zeros, input_dict[0])
コード例 #22
0
 def cell(carry, x):
     scale = parameter((2, ), zeros)
     return scale * np.array([2]) * carry * x, scale * np.array(
         [2]) * carry * x
コード例 #23
0
ファイル: test_modules.py プロジェクト: juliuskunze/jaxnet
    def loss(inputs):
        a = parameter((), ones, 'a')
        b = parameter((), lambda key, shape: 2 * jnp.ones(shape), 'b')

        return a + b