def local_abstract_batch_norm_train(fgraph, node):
    """Replace an ``AbstractBatchNormTrain`` node with an explicit graph.

    The normalized output is ``(x - mean) * (scale * invstd) + bias``.
    ``mean`` and ``var`` are taken over ``node.op.axes`` with dims kept, and
    ``invstd = 1 / sqrt(var + epsilon)``.

    The node may carry the optional ``running_mean`` (input 5) and
    ``running_var`` (input 6). In that case their moving-average updates are
    appended to the outputs. Each update is weighted by
    ``running_average_factor``, and the variance update is rescaled by
    ``m / (m - 1)`` with ``m = prod(x.shape) / prod(scale.shape)``.

    Parameters
    ----------
    fgraph : FunctionGraph
        The graph being rewritten (not used by this rewrite).
    node : Apply
        The candidate node.

    Returns
    -------
    list or None
        Replacement variables, one per output of ``node``, each given the
        broadcast pattern of the output it replaces. Returns ``None`` when
        the rewrite does not apply: wrong op, out-of-range axes, or
        non-``TensorType`` inputs.
    """
    if not isinstance(node.op, AbstractBatchNormTrain):
        return None

    x, scale, bias, epsilon, running_average_factor = node.inputs[:5]

    axes = node.op.axes
    # Valid axes are 0 .. x.ndim - 1. ``axis == x.ndim`` is out of range too,
    # hence ``>=`` rather than ``>``.
    if min(axes) < 0 or max(axes) >= x.ndim:
        return None
    if not all(
        isinstance(v.type, TensorType)
        for v in (x, scale, bias, epsilon, running_average_factor)
    ):
        return None
    # optional running_mean (input 5) and running_var (input 6)
    if any(not isinstance(v.type, TensorType) for v in node.inputs[5:7]):
        return None

    mean = x.mean(axes, keepdims=True)
    var = x.var(axes, keepdims=True)
    # The epsilon should not upcast the dtype.
    if var.dtype == "float32" and epsilon.dtype == "float64":
        epsilon = epsilon.astype("float32")
    invstd = inv(sqrt(var + epsilon))
    out = (x - mean) * (scale * invstd) + bias
    results = [out, mean, invstd]

    if len(node.inputs) > 5:
        running_mean = node.inputs[5]
        running_mean = (
            running_mean * (1.0 - running_average_factor)
            + mean * running_average_factor
        )
        results.append(running_mean)
    if len(node.inputs) > 6:
        # m: ratio of x's element count to scale's element count.
        m = aet.cast(prod(x.shape) / prod(scale.shape), config.floatX)
        running_var = node.inputs[6]
        running_var = (
            running_var * (1.0 - running_average_factor)
            + (m / (m - 1)) * var * running_average_factor
        )
        results.append(running_var)

    # Give each replacement the broadcast pattern of the output it replaces.
    results = [
        aet.patternbroadcast(r, r_orig.broadcastable)
        for (r, r_orig) in zip(results, node.outputs)
    ]

    # Attach the original node's stack trace to every newly built variable.
    for new_var in aesara.graph.basic.vars_between(node.inputs, results):
        if new_var not in node.inputs:
            copy_stack_trace(node.outputs[0], new_var)
    return results
def test_jax_CAReduce():
    """Compare JAX and Python backend results for CAReduce-based reductions.

    Covers a full ``sum`` of a vector, a ``sum`` along each axis of a matrix,
    and ``prod`` / ``all`` on a second, freshly created matrix variable.
    """

    def check(inp, out, value):
        # Wrap one output in its own graph and compare both backends.
        graph = FunctionGraph([inp], [out])
        compare_jax_and_py(graph, [value])

    def vec_value():
        return np.r_[1, 2, 3].astype(config.floatX)

    def mat_value():
        return np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)

    vec_in = vector("a")
    vec_in.tag.test_value = vec_value()
    check(vec_in, at_sum(vec_in, axis=None), vec_value())

    mat_in = matrix("a")
    mat_in.tag.test_value = mat_value()
    for axis in (0, 1):
        check(mat_in, at_sum(mat_in, axis=axis), mat_value())

    mat_in = matrix("a")
    mat_in.tag.test_value = mat_value()
    check(mat_in, prod(mat_in, axis=0), mat_value())
    check(mat_in, at_all(mat_in), mat_value())
def infer_shape(self, fgraph, node, i0_shapes):
    """Infer output shapes, starting from the shape feature's defaults.

    With ``self.axis`` set (negative values count from the end), output 0's
    shape is expressed through ``shape_ir`` for each of the input's dims.
    A ``RuntimeError`` is raised if the axis is out of range.

    With ``self.return_inverse``, the inverse output's shape is
    ``(prod(input_shape),)`` when no axis is given, otherwise the input's
    length along that axis. The inverse output sits at slot 2 when
    ``self.return_index`` is also set, and at slot 1 otherwise.
    """
    shape_feature = fgraph.shape_feature
    ret = shape_feature.default_infer_shape(fgraph, node, i0_shapes)

    axis = self.axis
    if axis is not None:
        ndim = len(i0_shapes[0])
        norm_axis = axis + ndim if axis < 0 else axis
        if not 0 <= norm_axis < ndim:
            raise RuntimeError(
                f"Unique axis `{self.axis}` is outside of input ndim = {ndim}."
            )
        ret[0] = tuple(
            shape_feature.shape_ir(dim, node.outputs[0]) for dim in range(ndim)
        )

    if not self.return_inverse:
        return ret

    if axis is None:
        inverse_shape = (prod(i0_shapes[0]),)
    else:
        inverse_shape = (i0_shapes[0][norm_axis],)
    # The index output, when requested, precedes the inverse output.
    ret[2 if self.return_index else 1] = inverse_shape
    return ret
def pad_dims(input, leftdims, rightdims):
    """Reshapes the input to a (leftdims + rightdims) tensor

    This helper function is used to convert pooling inputs with arbitrary
    non-pooling dimensions to the correct number of dimensions for the
    GPU pooling ops.

    The trailing ``rightdims`` dimensions are left untouched. The leading
    (non-pooling) dimensions are brought to exactly ``leftdims``, either by
    adding extra singleton dimensions on the left or by combining some
    existing dimensions on the left of the input. The result therefore
    always has ``leftdims + rightdims`` dimensions.

    Use `unpad_dims` to reshape back to the original dimensions.

    Parameters
    ----------
    input : tensor variable
        Tensor to reshape; must have at least ``rightdims`` dimensions.
    leftdims : int
        Number of leading (non-pooling) dimensions in the result.
    rightdims : int
        Number of trailing dimensions kept as-is (``0`` keeps none).

    Returns
    -------
    ``input`` itself if it already has ``leftdims + rightdims`` dimensions,
    otherwise a ``GpuReshape`` of it to that many dimensions.

    Raises
    ------
    AssertionError
        If ``input.ndim < rightdims``.

    Examples
    --------
    Given input of shape (3, 5, 7), ``pad_dims(input, 2, 2)``
    adds a singleton dimension and reshapes to (1, 3, 5, 7).
    Given that output from pad_dims, ``unpad_dims(output, input, 2, 2)``
    reshapes back to (3, 5, 7).

    Given input of shape (3, 5, 7, 9), ``pad_dims(input, 2, 2)``
    does not reshape and returns output with shape (3, 5, 7, 9).

    Given input of shape (3, 5, 7, 9, 11), ``pad_dims(input, 2, 2)``
    combines the first two dimensions and reshapes to (15, 7, 9, 11).

    Given input of shape (3, 5, 7, 9), ``pad_dims(input, 2, 3)``
    adds a singleton dimension and reshapes to (1, 3, 5, 7, 9).
    """
    assert input.ndim >= rightdims

    if input.ndim == (leftdims + rightdims):
        return input

    non_pool_ndim = input.ndim - rightdims
    # extract image dimensions. Slice from an explicit start index rather than
    # ``-rightdims``: when ``rightdims == 0`` the slice ``[-0:]`` would select
    # every dimension instead of none.
    img_shape = input.shape[non_pool_ndim:]

    if non_pool_ndim < leftdims:
        # too few dimensions, pad on the left
        dummy_dims = as_tensor([1] * (leftdims - non_pool_ndim))
        new_shape = join(0, dummy_dims, input.shape[:non_pool_ndim], img_shape)
    else:
        # too many dimensions, combine the leading dimensions
        batched_ndim = non_pool_ndim - leftdims + 1
        batch_size = prod(input.shape[:batched_ndim])
        # convert to a vector for join
        batch_size = shape_padright(batch_size, 1)
        new_shape = join(
            0, batch_size, input.shape[batched_ndim:non_pool_ndim], img_shape
        )

    # store in the required shape
    new_shape = cast(new_shape, "int64")
    input_ND = GpuReshape(leftdims + rightdims)(input, new_shape)
    return input_ND
def local_det_chol(fgraph, node):
    """Rewrite ``det(X)`` using an existing Cholesky factorization of ``X``.

    If some node in ``fgraph`` already computes ``L = cholesky(X)``, then
    ``det(X) = det(L) * det(L.T) = prod(diag(L)) ** 2``. The determinant is
    therefore taken from the factor's diagonal, computed here as
    ``prod(diag(L) ** 2)``.

    Returns
    -------
    list or None
        A one-element list with the replacement, or ``None`` when ``node`` is
        not a ``det`` or no Cholesky client of its input exists.
    """
    if node.op == det:
        (x,) = node.inputs
        for client, _ in fgraph.clients[x]:
            # Graph outputs appear in ``fgraph.clients`` as ("output", idx)
            # entries. The string client has no ``.op``, so skip it.
            if isinstance(client, str):
                continue
            if isinstance(client.op, Cholesky):
                L = client.outputs[0]
                return [prod(aet.extract_diag(L) ** 2)]
    return None
def normal(
    self,
    size,
    avg=0.0,
    std=1.0,
    ndim=None,
    dtype=None,
    nstreams=None,
    truncate=False,
    **kwargs,
):
    """
    Sample a tensor of values from a normal distribution.

    Samples are produced with the Box-Muller transform, applied to uniform
    samples drawn through ``self.uniform``.

    Parameters
    ----------
    size : int_vector_like
        Array dimensions for the output tensor.
    avg : float_like, optional
        The mean value of the normal distribution to sample from
        (defaults to 0.0).
    std : float_like, optional
        The standard deviation of the normal distribution to sample from
        (defaults to 1.0).
    truncate : bool, optional
        Truncates the normal distribution at 2 standard deviations if True
        (defaults to False). When this flag is set, the standard deviation of
        the result will be less than the one specified.
    ndim : int, optional
        The number of dimensions for the output tensor (defaults to None).
        This argument is necessary if the size argument is ambiguous on the
        number of dimensions.
    dtype : str, optional
        The data-type for the output tensor. If not specified, the dtype is
        inferred from avg and std, but it is at least as precise as floatX.
    nstreams : optional
        Forwarded unchanged to every ``self.uniform`` call.
    kwargs
        Other keyword arguments for random number generation (see uniform).

    Returns
    -------
    samples : TensorVariable
        An Aesara tensor of samples randomly drawn from a normal distribution.
        ``avg`` and ``std`` are wrapped in ``undefined_grad`` before use.
    """
    size = _check_size(size)
    avg = undefined_grad(as_tensor_variable(avg))
    std = undefined_grad(as_tensor_variable(std))

    if dtype is None:
        dtype = aes.upcast(config.floatX, avg.dtype, std.dtype)

    avg = at.cast(avg, dtype=dtype)
    std = at.cast(std, dtype=dtype)

    # generate even number of uniform samples
    # Do manual constant folding to lower optimizer work.
    # NOTE(review): ``size.prod(...)`` on a Constant presumably still builds a
    # symbolic Variable, not a folded value. If so, the ``elif``/``else``
    # branches near the end of this method are never taken. Confirm.
    if isinstance(size, Constant):
        n_odd_samples = size.prod(dtype="int64")
    else:
        n_odd_samples = prod(size, dtype="int64")
    # Round the requested count up to an even number. Box-Muller consumes
    # uniforms in pairs.
    n_even_samples = n_odd_samples + n_odd_samples % 2
    uniform = self.uniform(
        (n_even_samples,),
        low=0.0,
        high=1.0,
        ndim=1,
        dtype=dtype,
        nstreams=nstreams,
        **kwargs,
    )

    # box-muller transform: the first half of the uniforms gives the radius
    # r = sqrt(-2 log u1), and the second half gives the angle 2*pi*u2.
    # Together they yield two sets of n_even_samples // 2 normals, z0 and z1.
    u1 = uniform[: n_even_samples // 2]
    u2 = uniform[n_even_samples // 2 :]
    r = sqrt(-2.0 * log(u1))
    theta = np.array(2.0 * np.pi, dtype=dtype) * u2
    cos_theta, sin_theta = cos(theta), sin(theta)
    z0 = r * cos_theta
    z1 = r * sin_theta

    if truncate:
        # use valid samples: keep values already inside [-2, 2]
        to_fix0 = (z0 < -2.0) | (z0 > 2.0)
        to_fix1 = (z1 < -2.0) | (z1 > 2.0)
        z0_valid = z0[at.nonzero(~to_fix0)]
        z1_valid = z1[at.nonzero(~to_fix1)]

        # re-sample invalid samples. Only a fresh radius is drawn; the original
        # angles (cos_theta / sin_theta) are reused.
        to_fix0 = at.nonzero(to_fix0)[0]
        to_fix1 = at.nonzero(to_fix1)[0]
        n_fix_samples = to_fix0.size + to_fix1.size
        # Drawing u from [e**-2, 1] bounds r_fix = sqrt(-2 log u) by 2. The
        # re-sampled values therefore also lie in [-2, 2].
        lower = at.constant(1.0 / np.e**2, dtype=dtype)
        u_fix = self.uniform(
            (n_fix_samples,),
            low=lower,
            high=1.0,
            ndim=1,
            dtype=dtype,
            nstreams=nstreams,
            **kwargs,
        )
        r_fix = sqrt(-2.0 * log(u_fix))
        z0_fixed = r_fix[: to_fix0.size] * cos_theta[to_fix0]
        z1_fixed = r_fix[to_fix0.size :] * sin_theta[to_fix1]

        # pack everything together to a useful result
        norm_samples = at.join(0, z0_valid, z0_fixed, z1_valid, z1_fixed)
    else:
        norm_samples = at.join(0, z0, z1)
    # norm_samples holds n_even_samples values; drop the padding sample, if any.
    if isinstance(n_odd_samples, Variable):
        samples = norm_samples[:n_odd_samples]
    elif n_odd_samples % 2 == 1:
        samples = norm_samples[:-1]
    else:
        samples = norm_samples
    samples = reshape(samples, newshape=size, ndim=ndim)
    # Standard normals -> requested scale and location.
    samples *= std
    samples += avg

    return samples
def infer_shape(self, fgraph, node, shapes):
    """Return output shapes; flattened length when ``self.axis`` is None.

    With an axis, the input shapes are returned unchanged. Without one, the
    input is treated as flattened, so the single output is 1-D with length
    equal to the product of the input's dimensions.
    """
    if self.axis is not None:
        return shapes
    # Flatten
    flat_len = prod(shapes[0])
    return [(flat_len,)]