Example #1
File: layers.py  Project: SymJAX/SymJAX
    def __init__(self, input, crop_shape, deterministic, padding=0, seed=None):

        # if padding is given as a scalar, broadcast it to every
        # non-batch dimension
        if not hasattr(padding, "__len__"):
            pad_shape = [(padding, padding)] * (input.ndim - 1)
        # otherwise normalize each entry to a (before, after) pair
        else:
            pad_shape = [
                (pad, pad) if not hasattr(pad, "__len__") else pad
                for pad in padding
            ]

        assert len(pad_shape) == len(crop_shape)
        assert len(pad_shape) == input.ndim - 1

        start_indices = list()
        fixed_indices = list()
        for i, (pad, dim, crop) in enumerate(
            zip(pad_shape, input.shape[1:], crop_shape)
        ):
            # number of valid crop start positions along this axis
            maxval = pad[0] + pad[1] + dim - crop
            start_indices.append(
                T.random.randint(
                    minval=0,
                    maxval=maxval,
                    shape=(input.shape[0], 1),
                    dtype="int32",
                    seed=seed + i if seed is not None else seed,
                ))

            fixed_indices.append(
                T.ones((input.shape[0], 1), "int32") * (maxval // 2))
        start_indices = T.concatenate(start_indices, 1)
        fixed_indices = T.concatenate(fixed_indices, 1)

        # 1. when deterministic (centered crop), 0. when training (random crop)
        dirac = T.cast(deterministic, "float32")

        # pad the input
        pinput = T.pad(input, [(0, 0)] + pad_shape)

        # randomly positioned crops (one start index per batch element)
        routput = T.map(
            lambda x, indices: T.dynamic_slice(x, indices, crop_shape),
            sequences=[pinput, start_indices],
        )
        # deterministic (centered) crops
        doutput = T.map(
            lambda x, indices: T.dynamic_slice(x, indices, crop_shape),
            sequences=[pinput, fixed_indices],
        )

        # centered crop when deterministic, random crop otherwise
        # (note: in the full SymJAX source this value is presumably consumed
        # by the parent class; a bare `return` from __init__ would raise a
        # TypeError at instantiation)
        return doutput * dirac + (1 - dirac) * routput
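
# The crop-start arithmetic above can be sanity-checked with plain NumPy
# (an illustration only, not SymJAX API): with padding (p0, p1), axis length
# dim and crop length crop, the code draws start indices from
# [0, p0 + p1 + dim - crop).
import numpy as np

rng = np.random.default_rng(0)
dim, crop, p0, p1 = 32, 28, 2, 2
padded = np.pad(rng.normal(size=dim), (p0, p1))
start = rng.integers(0, p0 + p1 + dim - crop)  # same maxval as above
patch = padded[start:start + crop]
assert patch.shape == (crop,)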
Example #2
def test_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32")
    out = T.map(lambda a, w, u: (u - w) * a, [T.range(3)],
                non_sequences=[w, u])
    f = sj.function(u, outputs=out, updates={w: w + 1})
    assert np.array_equal(f(2), np.arange(3))  # w = 1: (2 - 1) * a
    assert np.array_equal(f(2), np.zeros(3))  # w = 2 after the update: (2 - 2) * a
    assert np.array_equal(f(0), -np.arange(3) * 3)  # w = 3: (0 - 3) * a
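
# For reference, the mapped computation in test_map reduces to jax.lax.map
# (a sketch; SymJAX's update of w is handled separately by sj.function):
import jax.numpy as jnp
from jax import lax

def mapped(w, u):
    # same body as the lambda above, applied over T.range(3)
    return lax.map(lambda a: (u - w) * a, jnp.arange(3.0))

print(mapped(1.0, 2.0))  # [0. 1. 2.]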
Example #3
def test_grad_map():
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")
    out = T.map(lambda a, w, u: w * a * u, (T.range(3), ),
                non_sequences=(w, u))
    g = sj.gradients(out.sum(), w)
    f = sj.function(u, outputs=g)

    # out.sum() = w * u * (0 + 1 + 2), so its gradient w.r.t. w is 3 * u
    assert np.array_equal(f(0), 0)
    assert np.array_equal(f(1), 3)
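
# The gradient in test_grad_map can be reproduced directly with jax.grad
# (a sketch of the equivalent computation, not the SymJAX internals):
import jax
import jax.numpy as jnp

def objective(w, u):
    # matches out.sum() above: the sum over a of w * a * u
    return jnp.sum(w * jnp.arange(3.0) * u)

print(jax.grad(objective)(1.0, 0.0))  # 0.0
print(jax.grad(objective)(1.0, 1.0))  # 3.0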
Example #4

# Now let's use map to compute a simple moving average (SMA). The
# loop moves a window over the signal and averages the values inside
# that window.

# in this case the mapped function also needs to be defined


def fn(window):
    # the function's inputs are the current elements of the (ordered)
    # sequences followed by the non_sequences values; here it receives
    # one window of the signal at a time
    return T.mean(window)


windowed = T.extract_signal_patches(signal, 10)
output = T.map(fn, sequences=[windowed])
f = symjax.function(signal, outputs=output)

fig, ax = plt.subplots(1, 1, figsize=(5, 2))

ax.plot(x, c="b")
ax.plot(f(x), c="r")
ax.set_title("SMA: 10")
ax.set_xticks([])
ax.set_yticks([])

plt.tight_layout()
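
# The mapped moving average can be sanity-checked against a plain NumPy
# convolution (a sketch; the window length 10 matches the
# extract_signal_patches call above):
import numpy as np

def sma(x, width=10):
    # moving average = convolution with a uniform kernel
    return np.convolve(x, np.ones(width) / width, mode="valid")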
Example #5
import jax
import numpy as np
import sys
sys.path.insert(0, "../")

import symjax
import symjax.tensor as T

# map: apply the function independently to each element of the sequence
xx = T.ones(10)
a = T.map(lambda a: a * 2, xx)
g = symjax.gradients(a.sum(), xx)[0]
f = symjax.function(outputs=[a, g])

# scan: the function maps (carry, x) -> (new carry, output)
xx = T.ones(10) * 2
a = T.scan(lambda c, x: (c * x, c * x), T.ones(1), xx)
g = symjax.gradients(a[1][-1], xx)[0]
f = symjax.function(outputs=[a, g])

# scan with updates
xx = T.range(5)
uu = T.ones((10, 2))
vvar = T.Variable(T.zeros((10, 2)))
vv = T.index_add(vvar, 1, 1)
a = T.scan(lambda c, x, p: (T.index_update(c, x, p[x]), 1), vv, xx, [vv])
# a = T.scan(lambda c, x: (c * x, c * x), T.ones(1), xx)
# a = T.scan(lambda c, x: (T.square(c), c[0]), uu, xx)
# g = symjax.gradients(a[1][-1], xx)
f = symjax.function(outputs=a[0], updates={vvar: vvar + 1})
print(f(), f(), f())
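
# T.scan follows the same (carry, x) -> (new carry, output) convention as
# jax.lax.scan; the cumulative-product example above reduces to (a sketch):
import jax.numpy as jnp
from jax import lax

carry, ys = lax.scan(lambda c, x: (c * x, c * x), jnp.ones(1), jnp.ones(10) * 2)
print(ys[-1])  # [1024.] = 2 ** 10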
Example #6
def test_grad_map_v2():
    sj.current_graph().reset()
    out = T.map(lambda a, b: a * b, (T.range(3), T.range(3)))
    f = sj.function(outputs=out)

    assert np.array_equal(f(), np.arange(3) * np.arange(3))
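
# Mapping over two sequences zips them elementwise, which is the same as
# vectorizing with jax.vmap (a sketch of the equivalent computation):
import jax
import jax.numpy as jnp

print(jax.vmap(lambda a, b: a * b)(jnp.arange(3), jnp.arange(3)))
# [0 1 4]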
Example #7
import symjax as sj
import symjax.tensor as T

w = T.Variable(1.0, dtype="float32")
u = T.Placeholder((), "float32")
out = T.map(lambda a, w, u: (u - w) * a, [T.range(3)], non_sequences=[w, u])
f = sj.function(u, outputs=out, updates={w: w + 1})
print(f(2))
# [0, 1, 2]
print(f(2))
# [0, 0, 0]
print(f(0))
# [0, -3, -6]


# gradient of the mapped computation with respect to the variable w
w.reset()
out = T.map(lambda a, w, u: w * a * u, [T.range(3)], non_sequences=[w, u])
g = sj.gradients(out.sum(), [w])[0]
f = sj.function(u, outputs=g)

print(f(0))
# 0
print(f(1))
# 3


# mapping jointly over two sequences
out = T.map(lambda a, b: a * b, [T.range(3), T.range(3)])
f = sj.function(outputs=out)

print(f())
# [0, 1, 4]