def test_jax_scan_tap_output():
    """Compare JAX and Python results for a scan whose output uses taps -1 and -3."""
    a_aet = scalar("a")

    def input_step_fn(y_tm1, y_tm3, a):
        # Name the inner-graph variables so they are recognizable in the graph.
        y_tm1.name = "y_tm1"
        y_tm3.name = "y_tm3"
        res = (y_tm1 + y_tm3) * a
        res.name = "y_t"
        return res

    # Three initial values are supplied because the deepest tap is -3.
    initial_vals = np.r_[-1.0, 1.3, 0.0].astype(config.floatX)
    tapped_output = {
        "initial": aet.as_tensor_variable(initial_vals),
        "taps": [-1, -3],
    }

    y_scan_aet, _ = scan(
        fn=input_step_fn,
        outputs_info=[tapped_output],
        non_sequences=[a_aet],
        n_steps=10,
        name="y_scan",
    )
    y_scan_aet.name = "y"
    y_scan_aet.owner.inputs[0].name = "y_all"

    out_fg = FunctionGraph([a_aet], [y_scan_aet])

    a_val = np.array(10.0).astype(config.floatX)
    test_input_vals = [a_val]
    compare_jax_and_py(out_fg, test_input_vals)
def outer_step(*args):
    """Run one checkpointed chunk with an inner ``scan`` and keep its last step.

    ``args`` is the flat argument list received from the outer scan. It is
    split, in order, into sequences, outputs of the previous iteration and
    non-sequences, using the lengths of ``o_sequences`` and
    ``o_nonsequences`` from the enclosing scope. The last sequence entry is
    used as the number of steps of the inner scan.

    Returns
    -------
    tuple
        ``([last_step_of_each_output, ...], updates)``.
    """
    n_seqs = len(o_sequences)
    # Use an explicit split index instead of ``-len(o_nonsequences)``: with
    # no non-sequences that expression is ``-0 == 0``, which would yield no
    # previous outputs and would treat *every* argument as a non-sequence.
    split = len(args) - len(o_nonsequences)

    i_sequences = list(args[:n_seqs])
    i_prev_outputs = list(args[n_seqs:split])
    i_non_sequences = list(args[split:])
    # Outputs without an initial state get a ``None`` entry each.
    i_outputs_infos = i_prev_outputs + [None] * len(new_nitsots)

    # Call the user-provided function with the proper arguments
    results, updates = scan(
        fn=fn,
        sequences=i_sequences[:-1],
        outputs_info=i_outputs_infos,
        non_sequences=i_non_sequences,
        name=name + "_inner",
        n_steps=i_sequences[-1],
    )
    if not isinstance(results, list):
        results = [results]

    # Keep only the last timestep of every output but keep all the updates
    return [r[-1] for r in results], updates
def test_one_sequence_one_output_weights_gpu2(self):
    """Check a simple RNN scan compiled in GPU mode against a NumPy reference.

    Also checks the number of host/GPU transfer nodes in the compiled outer
    graph and that the scan's inner graph contains no transfer at all.
    """
    # ``scan`` is called as a function below, so the Scan op class cannot be
    # reached through ``scan.op``; import it explicitly instead.
    from aesara.scan.op import Scan

    def f_rnn(u_t, x_tm1, W_in, W):
        return u_t * W_in + x_tm1 * W

    u = fvector("u")
    x0 = fscalar("x0")
    W_in = fscalar("win")
    W = fscalar("w")
    output, updates = scan(
        f_rnn,
        u,
        x0,
        [W_in, W],
        n_steps=None,
        truncate_gradient=-1,
        go_backwards=False,
        mode=mode_with_gpu,
    )

    f2 = aesara.function(
        [u, x0, W_in, W],
        output,
        updates=updates,
        allow_input_downcast=True,
        mode=mode_with_gpu,
    )

    # get random initial values
    rng = np.random.default_rng(utt.fetch_seed())
    v_u = rng.uniform(size=(4,), low=-5.0, high=5.0)
    v_x0 = rng.uniform()
    v_W = rng.uniform()
    v_W_in = rng.uniform()

    # compute the output in numpy
    v_out = np.zeros((4,))
    v_out[0] = v_u[0] * v_W_in + v_x0 * v_W
    for step in range(1, 4):
        v_out[step] = v_u[step] * v_W_in + v_out[step - 1] * v_W

    aesara_values = f2(v_u, v_x0, v_W_in, v_W)
    utt.assert_allclose(aesara_values, v_out)

    # The outer graph is expected to hold exactly one GPU->host and four
    # host->GPU transfers.
    topo = f2.maker.fgraph.toposort()
    assert sum(isinstance(node.op, HostFromGpu) for node in topo) == 1
    assert sum(isinstance(node.op, GpuFromHost) for node in topo) == 4

    scan_nodes = [node for node in topo if isinstance(node.op, Scan)]
    assert len(scan_nodes) == 1
    scan_node = scan_nodes[0]
    scan_node_topo = scan_node.op.fn.maker.fgraph.toposort()

    # check that there is no gpu transfer in the inner loop.
    assert any(isinstance(node.op, GpuElemwise) for node in scan_node_topo)
    assert not any(isinstance(node.op, HostFromGpu) for node in scan_node_topo)
    assert not any(isinstance(node.op, GpuFromHost) for node in scan_node_topo)
def test_gpu3_mixture_dtype_outputs(self):
    """Check a GPU scan that returns one float output and one int64 output."""

    def f_rnn(u_t, x_tm1, W_in, W):
        recurrent = u_t * W_in + x_tm1 * W
        as_int = aet.cast(u_t + x_tm1, "int64")
        return (recurrent, as_int)

    u = fvector("u")
    x0 = fscalar("x0")
    W_in = fscalar("win")
    W = fscalar("w")
    gpu_mode = self.mode_with_gpu

    # Only the first output is recurrent; the second has no initial state.
    output, updates = scan(
        f_rnn,
        u,
        [x0, None],
        [W_in, W],
        n_steps=None,
        truncate_gradient=-1,
        go_backwards=False,
        mode=gpu_mode,
    )

    f2 = aesara.function(
        [u, x0, W_in, W],
        output,
        updates=updates,
        allow_input_downcast=True,
        mode=gpu_mode,
    )

    # Draw random input values.
    rng = np.random.RandomState(utt.fetch_seed())
    v_u = rng.uniform(size=(4,), low=-5.0, high=5.0)
    v_x0 = rng.uniform()
    v_W = rng.uniform()
    v_W_in = rng.uniform()

    # Reference computation in NumPy.
    v_out1 = np.zeros((4,))
    v_out2 = np.zeros((4,), dtype="int64")
    v_out1[0] = v_u[0] * v_W_in + v_x0 * v_W
    v_out2[0] = v_u[0] + v_x0
    for step in range(1, 4):
        previous = v_out1[step - 1]
        v_out1[step] = v_u[step] * v_W_in + previous * v_W
        v_out2[step] = np.int64(v_u[step] + previous)

    aesara_out1, aesara_out2 = f2(v_u, v_x0, v_W_in, v_W)
    utt.assert_allclose(aesara_out1, v_out1)
    utt.assert_allclose(aesara_out2, v_out2)

    # Exactly one Scan node must remain, and it must run on the GPU.
    scan_nodes = [
        node for node in f2.maker.fgraph.toposort() if isinstance(node.op, Scan)
    ]
    assert len(scan_nodes) == 1
    assert self.is_scan_on_gpu(scan_nodes[0])
def test_gradient_scan():
    """Smoke test: MRG random numbers inside ``scan`` must not crash ``grad``.

    See https://groups.google.com/d/msg/theano-dev/UbcYyU5m-M8/UO9UgXqnQP0J
    """
    aesara_rng = MRG_RandomStream(10)
    w = shared(np.ones(1, dtype="float32"))

    def one_step(x):
        noise = aesara_rng.uniform((1,), dtype="float32")
        return x + noise * w

    x = vector(dtype="float32")
    values, updates = scan(one_step, outputs_info=x, n_steps=10)

    # Differentiate the summed final step with respect to the shared weight.
    final_total = aet_sum(values[-1])
    gw = grad(final_total, w)

    f = function([x], gw)
    # Only the absence of a crash is checked; the result is not inspected.
    f(np.arange(1, dtype="float32"))
def test_gradient_scan(mode):
    """Check the gradient through a scan that draws MRG random numbers.

    Parameters
    ----------
    mode
        Compilation mode forwarded to ``function``.
    """
    aesara_rng = MRG_RandomStream(10)
    w = shared(np.ones(1, dtype="float32"))

    def one_step(x):
        return x + aesara_rng.uniform((1,), dtype="float32") * w

    x = vector(dtype="float32")
    values, updates = scan(one_step, outputs_info=x, n_steps=10)
    gw = grad(at_sum(values[-1]), w)
    f = function([x], gw, mode=mode)

    # NOTE: the relative tolerance used to be ``1e6``, which made this check
    # pass for any value; ``1e-6`` makes the comparison meaningful.
    assert np.allclose(
        f(np.arange(1, dtype=np.float32)),
        np.array([0.13928187], dtype=np.float32),
        rtol=1e-6,
    )
def test_simple_shared_mrg_random():
    """Smoke test: a scan drawing MRG uniform samples compiles and runs twice."""
    aesara_rng = MRG_RandomStream(10)

    def draw_sample():
        return aesara_rng.uniform((2,), -1, 1)

    values, updates = scan(
        draw_sample,
        [],
        [],
        [],
        n_steps=5,
        truncate_gradient=-1,
        go_backwards=False,
    )
    my_f = function([], values, updates=updates, allow_input_downcast=True)

    # Just check for run-time errors; call it twice.
    for _ in range(2):
        my_f()
def setup_method(self):
    """Build the same recurrence with ``scan`` and ``scan_checkpoints``.

    Stores the final step of each result and the gradient of its sum with
    respect to ``A`` on the instance.
    """

    def multiply_step(prior_result, A):
        return prior_result * A

    self.k = iscalar("k")
    self.A = vector("A")

    # Reference version: plain scan.
    plain_result, _ = scan(
        fn=multiply_step,
        outputs_info=aet.ones_like(self.A),
        non_sequences=self.A,
        n_steps=self.k,
    )
    self.result = plain_result[-1]

    # Checkpointed version of the same loop.
    checkpointed_result, _ = scan_checkpoints(
        fn=multiply_step,
        outputs_info=aet.ones_like(self.A),
        non_sequences=self.A,
        n_steps=self.k,
        save_every_N=100,
    )
    self.result_check = checkpointed_result[-1]

    self.grad_A = aesara.grad(self.result.sum(), self.A)
    self.grad_A_check = aesara.grad(self.result_check.sum(), self.A)
def test_gibbs_chain(self):
    """Smoke test: a scan multiplying binomial samples compiles and runs in GPU mode."""
    rng = np.random.RandomState(utt.fetch_seed())
    binary_draws = rng.binomial(1, 0.5, size=(3, 20))
    v_vsample = np.array(binary_draws, dtype="float32")
    vsample = aesara.shared(v_vsample)
    trng = aesara.sandbox.rng_mrg.MRG_RandomStream(utt.fetch_seed())

    def f(vsample_tm1):
        mask = trng.binomial(vsample_tm1.shape, n=1, p=0.3, dtype="float32")
        return mask * vsample_tm1

    gpu_mode = self.mode_with_gpu
    aesara_vsamples, updates = scan(
        f,
        [],
        vsample,
        [],
        n_steps=10,
        truncate_gradient=-1,
        go_backwards=False,
        mode=gpu_mode,
    )
    last_sample = aesara_vsamples[-1]
    my_f = aesara.function(
        [],
        last_sample,
        updates=updates,
        allow_input_downcast=True,
        mode=gpu_mode,
    )

    # The values are left to be checked by debug mode; this is mostly a
    # "does the graph compile and run" kind of test.
    my_f()
def test_memory_reuse_gpudimshuffle(self):
    """Check scan's memory pre-allocation when an output is a dimshuffle.

    An optimization in GpuDimshuffle can make the pre-allocation logic
    falsely believe that a pre-allocated memory region has been used when
    it has not.
    """

    def inner_fn(seq1, recurrent_out):
        temp = seq1 + recurrent_out.sum()
        output1 = temp.dimshuffle(1, 0)
        output2 = temp.sum() + recurrent_out
        return output1, output2

    input1 = ftensor3()
    init = ftensor3()
    gpu_mode = self.mode_with_gpu

    # First output is non-recurrent, second one starts from ``init``.
    out, _ = scan(
        inner_fn,
        sequences=[input1],
        outputs_info=[None, init],
        mode=gpu_mode,
    )
    flat_outputs = [out[0].flatten(), out[1].flatten()]
    fct = aesara.function([input1, init], flat_outputs, mode=gpu_mode)

    input_val = np.ones((2, 1, 1), dtype="float32")
    init_val = np.ones((1, 1, 1), dtype="float32")
    output = fct(input_val, init_val)

    expected_output = (
        np.array([2, 4], dtype="float32"),
        np.array([3, 7], dtype="float32"),
    )
    utt.assert_allclose(output, expected_output)
def test_gpu_memory_usage(self):
    """Check that an RNN training function keeps a reasonable GPU footprint.

    Guards against a bug in which one of scan's optimizations was not
    applied, making the scan node compute large and unnecessary outputs
    that brought GPU memory usage to ~12G.
    """
    # Dimensionality of input and output data (not one-hot coded)
    n_in = 100
    n_out = 100
    # Number of neurons in hidden layer
    n_hid = 4000
    # Number of minibatches
    mb_size = 2
    # Time steps in minibatch
    mb_length = 200

    gpu_mode = self.mode_with_gpu_nodebug

    # Symbolic inputs
    xin = ftensor3(name="xin")
    yout = ftensor3(name="yout")

    # Network parameters, all initialized to zero
    def zero_weights(shape, name):
        return aesara.shared(np.zeros(shape, dtype="float32"), name=name)

    U = zero_weights((n_in, n_hid), "W_xin_to_l1")
    V = zero_weights((n_hid, n_hid), "W_l1_to_l1")
    W = zero_weights((n_hid, n_out), "W_l1_to_l2")
    nparams = [U, V, W]

    # Forward pass
    l1_base = dot(xin, U)

    def scan_l(baseline, last_step):
        return baseline + dot(last_step, V)

    zero_output = aet.alloc(np.asarray(0.0, dtype="float32"), mb_size, n_hid)

    l1_out, _ = scan(
        scan_l,
        sequences=[l1_base],
        outputs_info=[zero_output],
        mode=gpu_mode,
    )
    l2_out = dot(l1_out, W)

    # Cost and its gradient with respect to the parameters
    cost = tt_sum((l2_out - yout) ** 2)
    grads = aesara.grad(cost, nparams)
    updates = [(param, param - g) for param, g in zip(nparams, grads)]

    # Compile the training function
    feval_backprop = aesara.function(
        [xin, yout], cost, updates=updates, mode=gpu_mode
    )

    # Validate that the PushOutScanOutput optimization has been applied
    # by checking the number of outputs of the grad Scan node in the
    # compiled function.
    scan_nodes = [
        node
        for node in feval_backprop.maker.fgraph.toposort()
        if isinstance(node.op, Scan)
    ]

    # The grad scan is always the 2nd one according to toposort. If the
    # optimization has been applied, it has 2 outputs, otherwise 3.
    grad_scan_node = scan_nodes[1]
    n_grad_outputs = len(grad_scan_node.outputs)
    assert n_grad_outputs == 2, n_grad_outputs

    # Run the function to ensure the absence of a memory error
    x_val = np.zeros((mb_length, mb_size, n_in), dtype="float32")
    y_val = np.zeros((mb_length, mb_size, n_out), dtype="float32")
    feval_backprop(x_val, y_val)
def test_one_sequence_one_output_weights_gpu1(self):
    """Check a simple RNN scan whose output is explicitly moved to the GPU.

    The scan is built with ``InputToGpuOptimizer`` excluded from the mode.
    The result is compared with a NumPy reference, then the number of
    transfer nodes in the compiled outer graph and the absence of transfers
    in the scan's inner graph are verified.
    """

    def f_rnn(u_t, x_tm1, W_in, W):
        return u_t * W_in + x_tm1 * W

    u = fvector("u")
    x0 = fscalar("x0")
    W_in = fscalar("win")
    W = fscalar("w")

    # The following line is needed to have the first case being used
    # Otherwise, it is the second that is tested.
    mode = self.mode_with_gpu.excluding("InputToGpuOptimizer")
    output, updates = scan(
        f_rnn,
        u,
        x0,
        [W_in, W],
        n_steps=None,
        truncate_gradient=-1,
        go_backwards=False,
        mode=mode,
    )

    output = self.gpu_backend.gpu_from_host(output)
    f2 = aesara.function(
        [u, x0, W_in, W],
        output,
        updates=updates,
        allow_input_downcast=True,
        mode=self.mode_with_gpu,
    )

    # get random initial values, stored as float32
    rng = np.random.RandomState(utt.fetch_seed())
    v_u = np.asarray(rng.uniform(size=(4,), low=-5.0, high=5.0), dtype="float32")
    v_x0 = np.asarray(rng.uniform(), dtype="float32")
    v_W = np.asarray(rng.uniform(), dtype="float32")
    v_W_in = np.asarray(rng.uniform(), dtype="float32")

    # compute the output in numpy
    v_out = np.zeros((4,))
    v_out[0] = v_u[0] * v_W_in + v_x0 * v_W
    for step in range(1, 4):
        v_out[step] = v_u[step] * v_W_in + v_out[step - 1] * v_W

    aesara_values = f2(v_u, v_x0, v_W_in, v_W)
    utt.assert_allclose(aesara_values, v_out)

    # The compiled graph must contain exactly one Scan node.
    topo = f2.maker.fgraph.toposort()
    scan_nodes = [node for node in topo if isinstance(node.op, Scan)]
    assert len(scan_nodes) == 1
    scan_node = scan_nodes[0]

    # No GPU->host transfer and four host->GPU transfers in the outer graph.
    backend = self.gpu_backend
    assert sum(isinstance(node.op, backend.HostFromGpu) for node in topo) == 0
    assert sum(isinstance(node.op, backend.GpuFromHost) for node in topo) == 4

    scan_node_topo = scan_node.op.fn.maker.fgraph.toposort()

    # check that there is no gpu transfer in the inner loop.
    assert any(isinstance(node.op, backend.GpuElemwise) for node in scan_node_topo)
    assert not any(
        isinstance(node.op, backend.HostFromGpu) for node in scan_node_topo
    )
    assert not any(
        isinstance(node.op, backend.GpuFromHost) for node in scan_node_topo
    )
def scan_checkpoints(
    fn,
    sequences=None,
    outputs_info=None,
    non_sequences=None,
    name="checkpointscan_fn",
    n_steps=None,
    save_every_N=10,
    padding=True,
):
    """Scan function that uses less memory, but is more restrictive.

    In :func:`~aesara.scan`, if you compute the gradient of the output
    with respect to the input, you will have to store the intermediate
    results at each time step, which can be prohibitively huge. This
    function allows to do ``save_every_N`` steps of forward computations
    without storing the intermediate results, and to recompute them during
    the gradient computation.

    Notes
    -----
    Current assumptions:

    * Every sequence has the same length.
    * If ``n_steps`` is specified, it has the same value as the length of
      any sequence.
    * The value of ``save_every_N`` divides the number of steps the scan
      will run without remainder.
    * Only singly-recurrent and non-recurrent outputs are used.
      No multiple recurrences.
    * Only the last timestep of any output will ever be used.

    Parameters
    ----------
    fn
        ``fn`` is a function that describes the operations involved in one
        step of ``scan``. See the documentation of :func:`~aesara.scan`
        for more information.
    sequences
        ``sequences`` is the list of Aesara variables or dictionaries
        describing the sequences ``scan`` has to iterate over. All
        sequences must be the same length in this version of ``scan``.
        A list passed by the caller is not modified.
    outputs_info
        ``outputs_info`` is the list of Aesara variables or dictionaries
        describing the initial state of the outputs computed recurrently.
    non_sequences
        ``non_sequences`` is the list of arguments that are passed to
        ``fn`` at each steps. One can opt to exclude variable
        used in ``fn`` from this list as long as they are part of the
        computational graph, though for clarity we encourage not to do so.
    name
        Prefix used to name the inner (``name + "_inner"``) and outer
        (``name + "_outer"``) scans.
    n_steps
        ``n_steps`` is the number of steps to iterate given as an int
        or Aesara scalar (> 0). If any of the input sequences do not have
        enough elements, scan will raise an error. If n_steps is not
        provided, ``scan`` will figure out the amount of steps it should
        run given its input sequences.
    save_every_N
        ``save_every_N`` is the number of steps to go without storing
        the computations of ``scan`` (ie they will have to be recomputed
        during the gradient computation).
    padding
        If the length of the sequences is not a multiple of ``save_every_N``,
        the sequences will be zero padded to make this version of ``scan``
        work properly, but will also result in a memory copy. It can be
        avoided by setting ``padding`` to False, but you need to make
        sure the length of the sequences is a multiple of ``save_every_N``.

    Returns
    -------
    tuple
        Tuple of the form ``(outputs, updates)`` as in :func:`~aesara.scan`, but
        with a small change: It only contains the output at each
        ``save_every_N`` step. The time steps that are not returned by
        this function will be recomputed during the gradient computation
        (if any).

    Raises
    ------
    RuntimeError
        If an element of ``outputs_info`` is a dictionary with a ``"taps"``
        entry.

    See Also
    --------
    :func:`~aesara.scan`: Looping in Aesara.

    """
    # Standardize the format of input arguments. A list of sequences is
    # copied because padding rebinds its entries below and the caller's list
    # must not be modified.
    if sequences is None:
        sequences = []
    elif not isinstance(sequences, list):
        sequences = [sequences]
    else:
        sequences = list(sequences)

    if not isinstance(outputs_info, list):
        outputs_info = [outputs_info]

    if non_sequences is None:
        non_sequences = []
    elif not isinstance(non_sequences, list):
        non_sequences = [non_sequences]

    # Check that outputs_info has no taps:
    for element in outputs_info:
        if isinstance(element, dict) and "taps" in element:
            raise RuntimeError("scan_checkpoints doesn't work with taps.")

    # Determine how many steps the original scan would run (measured on the
    # unpadded sequences).
    if n_steps is None:
        n_steps = sequences[0].shape[0]

    # Compute the number of steps of the outer scan
    o_n_steps = at.cast(ceil(n_steps / save_every_N), "int64")

    # Compute the number of steps of the inner scan: ``save_every_N`` for
    # every chunk, except possibly a shorter final chunk.
    i_n_steps = save_every_N * at.ones((o_n_steps,), "int64")
    mod = n_steps % save_every_N
    last_n_steps = at.switch(eq(mod, 0), save_every_N, mod)
    i_n_steps = set_subtensor(i_n_steps[-1], last_n_steps)

    # Pad the sequences if needed
    if padding:
        # Since padding could be an empty tensor, Join returns a view of s.
        join = Join(view=0)
        for i, s in enumerate(sequences):
            # Number of rows missing to reach the next multiple of
            # ``save_every_N`` (zero when the length already is a multiple).
            pad_len = (save_every_N - s.shape[0] % save_every_N) % save_every_N
            z = at.zeros(
                [pad_len] + [s.shape[d] for d in range(1, s.ndim)],
                dtype=s.dtype,
            )
            sequences[i] = join(0, s, z)

    # Establish the input variables of the outer scan: every sequence is
    # split into chunks of ``save_every_N`` steps. Integer division keeps the
    # leading dimension an integer.
    o_sequences = [
        s.reshape(
            [s.shape[0] // save_every_N, save_every_N]
            + [s.shape[d] for d in range(1, s.ndim)],
            s.ndim + 1,
        )
        for s in sequences
    ]
    o_sequences.append(i_n_steps)
    new_nitsots = [i for i in outputs_info if i is None]
    o_nonsequences = non_sequences

    def outer_step(*args):
        # Separate the received arguments into their respective (seq, outputs
        # from previous iterations, nonseqs) categories. An explicit split
        # index is used because ``-len(o_nonsequences)`` is ``-0 == 0`` when
        # there are no non-sequences, which would mis-slice ``args``.
        n_seqs = len(o_sequences)
        split = len(args) - len(o_nonsequences)
        i_sequences = list(args[:n_seqs])
        i_prev_outputs = list(args[n_seqs:split])
        i_non_sequences = list(args[split:])
        i_outputs_infos = i_prev_outputs + [None] * len(new_nitsots)

        # Call the user-provided function with the proper arguments
        results, updates = scan(
            fn=fn,
            sequences=i_sequences[:-1],
            outputs_info=i_outputs_infos,
            non_sequences=i_non_sequences,
            name=name + "_inner",
            n_steps=i_sequences[-1],
        )
        if not isinstance(results, list):
            results = [results]

        # Keep only the last timestep of every output but keep all the updates
        return [r[-1] for r in results], updates

    results, updates = scan(
        fn=outer_step,
        sequences=o_sequences,
        outputs_info=outputs_info,
        non_sequences=o_nonsequences,
        name=name + "_outer",
        n_steps=o_n_steps,
        allow_gc=True,
    )

    return results, updates
def test_jax_scan_multiple_output():
    """Test a scan implementation of a SEIR model.

    SEIR model definition:
    S[t+1] = S[t] - B[t]
    E[t+1] = E[t] + B[t] - C[t]
    I[t+1] = I[t] + C[t] - D[t]

    B[t] ~ Binom(S[t], beta)
    C[t] ~ Binom(E[t], gamma)
    D[t] ~ Binom(I[t], delta)
    """

    def binomln(n, k):
        return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)

    def binom_log_prob(n, p, value):
        return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)

    # sequences
    aet_C = ivector("C_t")
    aet_D = ivector("D_t")
    # outputs_info (initial conditions)
    st0 = lscalar("s_t0")
    et0 = lscalar("e_t0")
    it0 = lscalar("i_t0")
    logp_c = scalar("logp_c")
    logp_d = scalar("logp_d")
    # non_sequences
    beta = scalar("beta")
    gamma = scalar("gamma")
    delta = scalar("delta")

    # TODO: Use random streams when their JAX conversions are implemented.
    # trng = aesara.tensor.random.RandomStream(1234)

    def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
        # bt0 = trng.binomial(n=st0, p=beta)
        # Deterministic stand-in for the binomial draw, cast back to S's dtype.
        bt0 = (st0 * beta).astype(st0.dtype)

        logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype)
        logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype)

        st1 = st0 - bt0
        et1 = et0 + bt0 - ct0
        it1 = it0 + ct0 - dt0
        return st1, et1, it1, logp_c1, logp_d1

    scan_outputs, _ = scan(
        fn=seir_one_step,
        sequences=[aet_C, aet_D],
        outputs_info=[st0, et0, it0, logp_c, logp_d],
        non_sequences=[beta, gamma, delta],
    )
    st, et, it, logp_c_all, logp_d_all = scan_outputs
    st.name = "S_t"
    et.name = "E_t"
    it.name = "I_t"
    logp_c_all.name = "C_t_logp"
    logp_d_all.name = "D_t_logp"

    graph_inputs = [aet_C, aet_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta]
    graph_outputs = [st, et, it, logp_c_all, logp_d_all]
    out_fg = FunctionGraph(graph_inputs, graph_outputs)

    # Concrete test values, in the same order as ``graph_inputs``.
    s0, e0, i0 = 100, 50, 25
    logp_c0 = np.array(0.0, dtype=config.floatX)
    logp_d0 = np.array(0.0, dtype=config.floatX)
    beta_val = np.array(0.277792, dtype=config.floatX)
    gamma_val = np.array(0.135330, dtype=config.floatX)
    delta_val = np.array(0.108753, dtype=config.floatX)
    C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32)
    D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32)

    test_input_vals = [
        C,
        D,
        s0,
        e0,
        i0,
        logp_c0,
        logp_d0,
        beta_val,
        gamma_val,
        delta_val,
    ]
    compare_jax_and_py(out_fg, test_input_vals)