def __init__(self, input, crop_shape, deterministic, padding=0, seed=None):
    """Pad then crop every entry of ``input``, randomly or centred.

    For each entry along the first axis of ``input`` (iterated with
    ``T.map``), every remaining axis is padded and then cropped to
    ``crop_shape``. Two crops are built:

    * a random crop whose start offset on each axis is drawn uniformly over
      every valid position, and
    * a centred crop whose start offset is the midpoint of the valid range.

    ``deterministic`` selects between them by blending with a 0/1 weight.

    Args:
        input: tensor whose first axis indexes the entries to crop; all other
            axes are padded and cropped.
        crop_shape: output size for each non-leading axis of ``input``.
        deterministic: flag cast to float32. With value 1 the centred crop is
            returned; with value 0 the random crop is returned.
        padding: a scalar applied on both sides of every cropped axis, or one
            entry per cropped axis. Each entry is either a scalar (applied on
            both sides) or a ``(before, after)`` pair.
        seed: optional base seed; axis ``i`` draws its offsets with
            ``seed + i``. ``None`` is forwarded unchanged.

    Returns:
        ``centred * dirac + (1 - dirac) * random``, where ``dirac`` is
        ``deterministic`` cast to float32.
    """
    # Normalise ``padding`` to one (before, after) pair per cropped axis.
    if not hasattr(padding, "__len__"):
        pad_shape = [(padding, padding)] * (input.ndim - 1)
    else:
        pad_shape = [
            (pad, pad) if not hasattr(pad, "__len__") else pad for pad in padding
        ]
    assert len(pad_shape) == len(crop_shape), "need one crop size per padded axis"
    assert len(pad_shape) == input.ndim - 1, (
        "padding/crop_shape must cover every axis of input except the first"
    )

    batch = input.shape[0]
    start_indices = []
    fixed_indices = []
    for i, (pad, dim, crop) in enumerate(zip(pad_shape, input.shape[1:], crop_shape)):
        # Largest valid start offset on this axis of the padded input
        # (inclusive): offsets 0..maxval all keep the crop inside bounds.
        maxval = pad[0] + pad[1] + dim - crop
        # randint's upper bound is exclusive (jax semantics), so pass
        # maxval + 1 to make the last valid offset reachable. When the crop
        # exactly fills the axis this yields the range [0, 1) -> always 0.
        start_indices.append(
            T.random.randint(
                minval=0,
                maxval=maxval + 1,
                shape=(batch, 1),
                dtype="int32",
                seed=seed + i if seed is not None else seed,
            )
        )
        # Centred crop: midpoint of the valid offset range.
        fixed_indices.append(T.ones((batch, 1), "int32") * (maxval // 2))

    start_indices = T.concatenate(start_indices, 1)
    fixed_indices = T.concatenate(fixed_indices, 1)
    dirac = T.cast(deterministic, "float32")

    # Pad only the cropped axes; the first (mapped-over) axis is left as is.
    pinput = T.pad(input, [(0, 0)] + pad_shape)

    def crop_one(x, indices):
        # Slice a crop_shape window out of one entry, starting at ``indices``.
        return T.dynamic_slice(x, indices, crop_shape)

    routput = T.map(crop_one, sequences=[pinput, start_indices])
    doutput = T.map(crop_one, sequences=[pinput, fixed_indices])
    return doutput * dirac + (1 - dirac) * routput
def test_map():
    """``T.map`` reads the current value of a Variable passed as a non-sequence.

    ``w`` starts at 1.0 and the compiled function increments it after every
    call, so the three calls below see w = 1, 2, 3 in turn.
    """
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32")
    scaled = T.map(
        lambda elem, w_val, u_val: (u_val - w_val) * elem,
        [T.range(3)],
        non_sequences=[w, u],
    )
    step = sj.function(u, outputs=scaled, updates={w: w + 1})
    expectations = (
        (2, np.arange(3)),  # (2 - 1) * [0, 1, 2]
        (2, np.zeros(3)),  # (2 - 2) * [0, 1, 2]
        (0, -np.arange(3) * 3),  # (0 - 3) * [0, 1, 2]
    )
    for u_value, expected in expectations:
        assert np.array_equal(step(u_value), expected)
def test_grad_map():
    """Differentiate a summed ``T.map`` with respect to a non-sequence Variable.

    The mapped values are ``w * a * u`` for a in 0..2, so the derivative of
    their sum with respect to ``w`` is ``u * (0 + 1 + 2) = 3 * u``.
    """
    sj.current_graph().reset()
    w = T.Variable(1.0, dtype="float32")
    u = T.Placeholder((), "float32", name="u")
    mapped = T.map(
        lambda elem, w_val, u_val: w_val * elem * u_val,
        (T.range(3),),
        non_sequences=(w, u),
    )
    grad_fn = sj.function(u, outputs=sj.gradients(mapped.sum(), w))
    for u_value, expected in ((0, 0), (1, 3)):
        assert np.array_equal(grad_fn(u_value), expected)
# Next, a simple moving average (SMA) written as a map: slide a window of
# 10 samples along the signal and replace each window by its mean.
#
# ``T.map`` calls the mapped function once per step, passing the current
# element of every sequence (followed by any ``non_sequences``) as positional
# arguments -- there is no loop-index argument. Here there is one sequence,
# so the function receives a single window per call.
def fn(window):
    """Return the mean value of one window."""
    return T.mean(window)


# presumably one patch per window position -- confirm against
# T.extract_signal_patches
windowed = T.extract_signal_patches(signal, 10)
output = T.map(fn, sequences=[windowed])
f = symjax.function(signal, outputs=output)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 2))
ax.plot(x, c="b")  # raw signal
# NOTE(review): the averaged curve may be shorter than x (one value per
# window), so it is not necessarily aligned sample-for-sample with x.
ax.plot(f(x), c="r")
ax.set_title("SMA: 10")
ax.set_xticks([])
ax.set_yticks([])
plt.tight_layout()
import sys

import jax
import numpy as np

# Put the directory above the current working directory first on the import
# path (presumably so a local symjax checkout wins over an installed one).
sys.path.insert(0, "../")
import symjax  # noqa: E402
import symjax.tensor as T  # noqa: E402

# --- map: double every element, then differentiate the sum w.r.t. the input.
xx = T.ones(10)
a = T.map(lambda elem: elem * 2, xx)
g = symjax.gradients(a.sum(), xx)[0]
f = symjax.function(outputs=[a, g])


# --- scan: the carry is a running product; each step also emits it.
def _running_product(carry, elem):
    return carry * elem, carry * elem


xx = T.ones(10) * 2
a = T.scan(_running_product, T.ones(1), xx)
g = symjax.gradients(a[1][-1], xx)[0]
f = symjax.function(outputs=[a, g])


# --- scan with a Variable update: the carry starts as ``vv``; step ``idx``
# writes ``source[idx]`` into entry ``idx`` of the carry and emits 1.
def _copy_entry(carry, idx, source):
    return T.index_update(carry, idx, source[idx]), 1


xx = T.range(5)
uu = T.ones((10, 2))  # not used below
vvar = T.Variable(T.zeros((10, 2)))
vv = T.index_add(vvar, 1, 1)
a = T.scan(_copy_entry, vv, xx, [vv])
f = symjax.function(outputs=a[0], updates={vvar: vvar + 1})
# Three successive calls; ``vvar`` is incremented after each one.
print(*(f() for _ in range(3)))
def test_grad_map_v2():
    """``T.map`` over two sequences consumes them in lockstep.

    Despite the name, no gradient is taken: this checks that the mapped
    function receives one element from each sequence per step.
    """
    sj.current_graph().reset()
    sequences = (T.range(3), T.range(3))
    product = T.map(lambda lhs, rhs: lhs * rhs, sequences)
    evaluate = sj.function(outputs=product)
    expected = np.arange(3) ** 2  # [0, 1, 4]
    assert np.array_equal(evaluate(), expected)
import symjax as sj
import symjax.tensor as T

# --- 1. A map whose non-sequences include a Variable updated after each call.
w = T.Variable(1.0, dtype="float32")
u = T.Placeholder((), "float32")
out = T.map(lambda a, w, u: (u - w) * a, [T.range(3)], non_sequences=[w, u])
f = sj.function(u, outputs=out, updates={w: w + 1})
# w is 1, 2, 3 on the three calls below.
for u_value in (2, 2, 0):
    print(f(u_value))  # [0, 1, 2], then [0, 0, 0], then [0, -3, -6]

# --- 2. Gradient of the summed map with respect to the Variable.
w.reset()  # presumably restores w to its initial value (1.0)
out = T.map(lambda a, w, u: w * a * u, [T.range(3)], non_sequences=[w, u])
grads = sj.gradients(out.sum(), [w])
g = grads[0]
f = sj.function(u, outputs=g)
# d/dw sum(w * a * u) = u * (0 + 1 + 2) = 3u
for u_value in (0, 1):
    print(f(u_value))  # 0, then 3

# --- 3. A map over two sequences, consumed in lockstep.
out = T.map(lambda a, b: a * b, [T.range(3), T.range(3)])
f = sj.function(outputs=out)
print(f())  # [0, 1, 4]