def local_abstract_batch_norm_train_grad(node):
    """Expand AbstractBatchNormTrainGrad into explicit tensor algebra.

    Returns a list of three gradients (w.r.t. inputs, scale and bias),
    or None when the node cannot be handled by this rewrite.
    """
    if not isinstance(node.op, AbstractBatchNormTrainGrad):
        return None

    x, dy, scale, x_mean, x_invstd, epsilon = node.inputs
    axes = node.op.axes
    # Valid reduction axes are 0 .. x.ndim - 1 (negative axes are rejected
    # outright).  The previous bound (`> x.ndim`) was off by one: it let
    # axes == x.ndim through, so tt.mean would raise inside the rewriter
    # instead of cleanly declining with None.
    if min(axes) < 0 or max(axes) >= x.ndim:
        return None
    if (not isinstance(x.type, TensorType)
            or not isinstance(dy.type, TensorType)
            or not isinstance(scale.type, TensorType)
            or not isinstance(x_mean.type, TensorType)
            or not isinstance(x_invstd.type, TensorType)
            or not isinstance(epsilon.type, TensorType)):
        return None

    x_diff = x - x_mean
    mean_dy_x_diff = tt.mean(dy * x_diff, axis=axes, keepdims=True)
    c = (dy * x_invstd) - x_diff * (mean_dy_x_diff * (x_invstd**3))

    g_wrt_inputs = scale * (c - tt.mean(c, axis=axes, keepdims=True))
    g_wrt_scale = tt.sum(dy * x_invstd * x_diff, axis=axes, keepdims=True)
    g_wrt_bias = tt.sum(dy, axis=axes, keepdims=True)

    results = [g_wrt_inputs, g_wrt_scale, g_wrt_bias]
    # Restore the broadcast pattern expected by the original outputs.
    results = [
        tt.patternbroadcast(r, r_orig.broadcastable)
        for (r, r_orig) in zip(results, node.outputs)
    ]
    # Tag every newly introduced variable with the old node's stack trace.
    for var in aesara.gof.graph.variables(node.inputs, results):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return results
def local_inv_1_plus_exp(node):
    """
    1/(1+exp(x)) -> sigm(-x)
    """
    # This rewrite exists for numerical stability, so client counts
    # are deliberately not checked.
    if node.op != tensor.inv:
        return
    inv_arg = node.inputs[0]
    if not (inv_arg.owner and inv_arg.owner.op == tensor.add):
        return
    scalars, scalar_inputs, nonconsts = opt.scalarconsts_rest(
        inv_arg.owner.inputs, only_process_constants=True)
    # scalar_inputs are potentially dimshuffled and fill'd scalars
    if len(nonconsts) != 1:
        return
    exp_term = nonconsts[0]
    if not (exp_term.owner and exp_term.owner.op == tensor.exp):
        return
    if not (scalars and np.allclose(np.sum(scalars), 1)):
        return
    out = opt._fill_chain(
        sigmoid(tensor.neg(exp_term.owner.inputs[0])),
        scalar_inputs,
    )
    # Merge the stack traces of:
    #   exp(x): exp_term,
    #   1 + exp(x): inv_arg,
    #   1 / (1 + exp(x)): node.outputs[0]
    copy_stack_trace([exp_term, inv_arg, node.outputs[0]], out)
    return out
def local_abstractconv_gemm(node):
    """Lower AbstractConv2d to the gemm-backed CorrMM implementation."""
    # If aesara.config.blas.ldflags is empty, Aesara will use
    # a NumPy C implementation of [sd]gemm_.  Without a C++ compiler
    # (or for float16 inputs) this path is unavailable.
    if aesara.config.cxx == "" or node.inputs[0].dtype == "float16":
        return
    if not isinstance(node.op, AbstractConv2d):
        return None
    img, kern = node.inputs
    if not (isinstance(img.type, TensorType)
            and isinstance(kern.type, TensorType)):
        return None
    # CorrMM computes a correlation; flip the last two kernel axes to
    # obtain a true convolution when requested.
    if node.op.filter_flip:
        flip = (slice(None),) * (kern.ndim - 2) + (slice(None, None, -1),) * 2
        kern = kern[flip]
    corr_op = CorrMM(
        border_mode=node.op.border_mode,
        subsample=node.op.subsample,
        filter_dilation=node.op.filter_dilation,
        num_groups=node.op.num_groups,
        unshared=node.op.unshared,
    )
    new_out = corr_op(img, kern)
    copy_stack_trace(node.outputs[0], new_out)
    return [new_out]
def local_sigm_times_exp(node):
    """
    exp(x) * sigm(-x) -> sigm(x)
    exp(-x) * sigm(x) -> sigm(-x)

    todo: add stack traces to the intermediate variables
    """
    # Only multiplications are candidates.
    if node.op != tensor.mul:
        return None
    # Build the tree of multiplications rooted at this node and try the
    # core sigm*exp simplification on it.
    tree = parse_mul_tree(node.outputs[0])
    if not perform_sigm_times_exp(tree):
        # Nothing changed.
        return None
    # The rewrite may leave multiplications by 1 behind; strip them,
    # then rebuild the output expression from the simplified tree.
    new_out = compute_mul(simplify_mul(tree))
    # keep the stack trace
    copy_stack_trace(node.outputs[0], new_out)
    return [new_out]
def local_conv2d_cpu(node):
    """Lower AbstractConv2d to the legacy CPU conv2d implementation."""
    if not isinstance(node.op, AbstractConv2d) or node.inputs[0].dtype == "float16":
        return None
    img, kern = node.inputs
    if not (isinstance(img.type, TensorType)
            and isinstance(kern.type, TensorType)):
        return None
    op = node.op
    # The legacy implementation only supports plain full/valid, flipped,
    # ungrouped, shared, undilated convolutions.
    unsupported = (
        op.border_mode not in ("full", "valid")
        or not op.filter_flip  # non-flipping path not tested yet
        or op.num_groups > 1
        or op.unshared
        or op.filter_dilation != (1, 1)
    )
    if unsupported:
        return None
    new_out = conv2d(
        img,
        kern,
        op.imshp,
        op.kshp,
        border_mode=op.border_mode,
        subsample=op.subsample,
    )
    copy_stack_trace(node.outputs[0], new_out)
    return [new_out]
def local_abstractconv3d_gradweight_gemm(node):
    """Lower AbstractConv3d_gradWeights to the gemm-backed Corr3dMM op."""
    # If aesara.config.blas.ldflags is empty, Aesara will use
    # a NumPy C implementation of [sd]gemm_.  Without a C++ compiler
    # (or for float16 inputs) this path is unavailable.
    if aesara.config.cxx == "" or node.inputs[0].dtype == "float16":
        return
    if not isinstance(node.op, AbstractConv3d_gradWeights):
        return None
    img, topgrad, shape = node.inputs
    if not (isinstance(img.type, TensorType)
            and isinstance(topgrad.type, TensorType)):
        return None
    grad_op = Corr3dMMGradWeights(
        border_mode=node.op.border_mode,
        subsample=node.op.subsample,
        filter_dilation=node.op.filter_dilation,
        num_groups=node.op.num_groups,
    )
    grad_w = grad_op(img, topgrad, shape)
    copy_stack_trace(node.outputs[0], grad_w)
    # Correlation gradient -> convolution gradient: flip the kernel on
    # all three spatial axes when the abstract op requested flipping.
    if node.op.filter_flip:
        grad_w = grad_w[:, :, ::-1, ::-1, ::-1]
    grad_w = aesara.tensor.patternbroadcast(
        grad_w, node.outputs[0].broadcastable)
    copy_stack_trace(node.outputs[0], grad_w)
    return [grad_w]
def local_abstract_batch_norm_inference(node):
    """Expand AbstractBatchNormInference into explicit tensor algebra."""
    if not isinstance(node.op, AbstractBatchNormInference):
        return None

    x, scale, bias, estimated_mean, estimated_variance, epsilon = node.inputs
    # All six inputs must be plain tensors for this rewrite to apply.
    if not all(isinstance(v.type, TensorType) for v in node.inputs):
        return None

    # The epsilon should not upcast the dtype.
    if estimated_variance.dtype == "float32" and epsilon.dtype == "float64":
        epsilon = epsilon.astype("float32")

    out = (x - estimated_mean) * (
        scale / tt.sqrt(estimated_variance + epsilon)) + bias
    out = tt.patternbroadcast(out, node.outputs[0].broadcastable)
    # Tag every newly introduced variable with the old node's stack trace.
    for var in aesara.gof.graph.variables(node.inputs, [out]):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return [out]
def local_ultra_fast_sigmoid(node):
    """
    When enabled, change all sigmoid to ultra_fast_sigmoid.

    For example do mode.including('local_ultra_fast_sigmoid')
    or use the Aesara flag optimizer_including=local_ultra_fast_sigmoid.

    This speeds up the sigmoid op by using an approximation.

    This is done after the stabilization and specialize phases
    to avoid interacting with them.

    """
    if not (isinstance(node.op, tensor.Elemwise)
            and node.op.scalar_op == scalar_sigmoid):
        return

    out = ultra_fast_sigmoid(node.inputs[0])
    copy_stack_trace(node.outputs[0], out)

    def values_eq_approx_remove_low_prec(a, b):
        # atol is found by trial/error.
        # Other test could fail without good reason.
        return tensor.TensorType.values_eq_approx(a, b, atol=0.02)

    # Let DebugMode know that there this opt approx the values.
    out.tag.values_eq_approx = values_eq_approx_remove_low_prec
    return [out]
def local_inplace_sparse_block_outer(node):
    """
    SparseBlockOuter(inplace=False) -> SparseBlockOuter(inplace=True)
    """
    # Decline anything that is not an out-of-place SparseBlockOuter.
    if not isinstance(node.op, SparseBlockOuter) or node.op.inplace:
        return False
    replacement = sparse_block_outer_inplace(*node.inputs)
    copy_stack_trace(node.outputs[0], replacement)
    return [replacement]
def local_inplace_DiagonalSubtensor(node):
    """Also work for IncDiagonalSubtensor."""
    op = node.op
    # Decline ops that are already inplace or of a different type.
    if not isinstance(op, (DiagonalSubtensor, IncDiagonalSubtensor)):
        return False
    if op.inplace:
        return False
    replacement = op.__class__(inplace=True)(*node.inputs)
    copy_stack_trace(node.outputs[0], replacement)
    return [replacement]
def local_abstract_batch_norm_train(node):
    """Expand AbstractBatchNormTrain into explicit tensor algebra.

    Produces the normalized output, the batch mean and inverse standard
    deviation, and — when the node carries them as extra inputs — the
    updated running mean and running variance.
    """
    if not isinstance(node.op, AbstractBatchNormTrain):
        return None

    x, scale, bias, epsilon, running_average_factor = node.inputs[:5]
    axes = node.op.axes
    # Valid reduction axes are 0 .. x.ndim - 1 (negative axes are rejected
    # outright).  The previous bound (`> x.ndim`) was off by one: it let
    # axes == x.ndim through, so x.mean would raise inside the rewriter
    # instead of cleanly declining with None.
    if min(axes) < 0 or max(axes) >= x.ndim:
        return None
    if (not isinstance(x.type, TensorType)
            or not isinstance(scale.type, TensorType)
            or not isinstance(bias.type, TensorType)
            or not isinstance(epsilon.type, TensorType)
            or not isinstance(running_average_factor.type, TensorType)):
        return None
    # optional running_mean and running_var
    if len(node.inputs) > 5 and not isinstance(node.inputs[5].type, TensorType):
        return None
    if len(node.inputs) > 6 and not isinstance(node.inputs[6].type, TensorType):
        return None

    mean = x.mean(axes, keepdims=True)
    var = x.var(axes, keepdims=True)
    # The epsilon should not upcast the dtype.
    if var.dtype == "float32" and epsilon.dtype == "float64":
        epsilon = epsilon.astype("float32")
    invstd = tt.inv(tt.sqrt(var + epsilon))
    out = (x - mean) * (scale * invstd) + bias
    results = [out, mean, invstd]

    if len(node.inputs) > 5:
        # Exponential moving average of the batch mean.
        running_mean = node.inputs[5]
        running_mean = (running_mean * (1.0 - running_average_factor)
                        + mean * running_average_factor)
        results.append(running_mean)
    if len(node.inputs) > 6:
        # Running variance uses the unbiased estimator (Bessel's
        # correction m/(m-1)); m is the number of elements averaged over.
        m = tt.cast(
            tt.prod(x.shape) / tt.prod(scale.shape), aesara.config.floatX)
        running_var = node.inputs[6]
        running_var = (running_var * (1.0 - running_average_factor)
                       + (m / (m - 1)) * var * running_average_factor)
        results.append(running_var)

    # Restore the broadcast pattern expected by the original outputs.
    results = [
        tt.patternbroadcast(r, r_orig.broadcastable)
        for (r, r_orig) in zip(results, node.outputs)
    ]
    # Tag every newly introduced variable with the old node's stack trace.
    for var in aesara.gof.graph.variables(node.inputs, results):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return results
def local_hard_sigmoid(node):
    """Replace sigmoid with the piecewise-linear hard_sigmoid approximation."""
    if not (isinstance(node.op, tensor.Elemwise)
            and node.op.scalar_op == scalar_sigmoid):
        return

    out = hard_sigmoid(node.inputs[0])
    copy_stack_trace(node.outputs[0], out)

    def values_eq_approx_remove_low_prec(a, b):
        # atol is found by trial/error.
        # Other test could fail without good reason.
        return tensor.TensorType.values_eq_approx(a, b, atol=0.1)

    # Let DebugMode know that there this opt approx the values.
    out.tag.values_eq_approx = values_eq_approx_remove_low_prec
    return [out]
def local_max_and_argmax(node):
    """
    If we don't use the argmax, change it to a max only.
    """
    if not isinstance(node.op, tt.MaxAndArgmax):
        return
    axis = node.op.get_params(node)
    max_out, argmax_out = node.outputs
    # Argmax output unused -> compute only the max.
    if not argmax_out.clients:
        replacement = tt.Max(axis)(node.inputs[0])
        copy_stack_trace(max_out, replacement)
        return [replacement, None]
    # Max output unused -> compute only the argmax.
    if not max_out.clients:
        replacement = tt.Argmax(axis)(node.inputs[0])
        copy_stack_trace(max_out, replacement)
        return [None, replacement]
def local_1msigmoid(node):
    """
    1-sigm(x) -> sigm(-x)
    """
    if node.op != tensor.sub:
        return
    sub_l, sub_r = node.inputs
    if len(sub_r.clients) > 1:
        return  # graph is using both sigm and 1-sigm
    if not (sub_r.owner and sub_r.owner.op == sigmoid):
        return
    try:
        val_l = opt.get_scalar_constant_value(sub_l)
    except tensor.NotScalarConstantError:
        return
    # Only rewrite when the left operand is the constant 1.
    if not np.allclose(np.sum(val_l), 1):
        return
    out = sigmoid(-sub_r.owner.inputs[0])
    copy_stack_trace([sub_r, node.outputs[0]], out)
    return [out]
def local_exp_over_1_plus_exp(node): """ exp(x)/(1+exp(x)) -> sigm(x) c/(1+exp(x)) -> c*sigm(-x) """ # this optimization should be done for numerical stability # so we don't care to check client counts if node.op == tensor.true_div: # find all the exp() terms in the numerator num, denom = node.inputs num_exp_x, num_rest, num_neg = partition_num_or_denom(num, is_exp) denom_1pexp, denom_rest, denom_neg = partition_num_or_denom( denom, is_1pexp) sigmoids = [] for t in denom_1pexp: if t in num_exp_x: # case: exp(x) /(1+exp(x)) sigmoids.append(sigmoid(t)) del num_exp_x[num_exp_x.index(t)] else: # case: 1/(1+exp(x)) sigmoids.append(sigmoid(-t)) copy_stack_trace(node.outputs[0], sigmoids[-1]) if not sigmoids: # we didn't find any. abort return # put the new numerator together new_num = sigmoids + [tensor.exp(t) for t in num_exp_x] + num_rest if len(new_num) == 1: new_num = new_num[0] else: new_num = tensor.mul(*new_num) if num_neg ^ denom_neg: new_num = -new_num copy_stack_trace(num, new_num) if len(denom_rest) == 0: return [new_num] elif len(denom_rest) == 1: out = new_num / denom_rest[0] else: out = new_num / tensor.mul(*denom_rest) copy_stack_trace(node.outputs[0], out) return [out]
def local_max_to_min(node):
    """
    Change -(max(-x)) to min.

    This is tested in tensor/tests/test_basic.py:test_min_max.

    Notes
    -----
    We don't need an opt that will do the reverse as by default
    the interface put only MaxAndArgmax into the graph.

    """
    if node.op == tt.neg and node.inputs[0].owner:
        # Renamed locals: the originals shadowed the builtins `max`
        # (and used a confusing `neg` for the inner maximum's argument).
        max_var = node.inputs[0]
        if (max_var.owner and isinstance(max_var.owner.op, CAReduce)
                and max_var.owner.op.scalar_op == scal.maximum):
            inner = max_var.owner.inputs[0]
            if inner.owner and inner.owner.op == tt.neg:
                new = tt.Min(max_var.owner.op.axis)(inner.owner.inputs[0])
                copy_stack_trace(node.outputs[0], new)
                return [new]
    return False
def local_opt(node):
    # NOTE(review): reads `OP`, `maker`, `cuda_only`, `host_from_gpu`,
    # `move_to_gpu`, `GpuFromHost`, `get_context`, `Op` and `safe_to_cpu`
    # from an enclosing scope not visible here — presumably this is the
    # body of a register-GPU-rewrite factory; confirm against the caller.
    if type(node.op) in OP:
        # Either one of our inputs is on the gpu or
        # all of our clients are on the gpu
        replace = False
        # TODO: Maybe set context_name with infer_context_name()?
        context_name = None
        # We replace if any input is a host_from_gpu
        for i in node.inputs:
            if i.owner and i.owner.op == host_from_gpu and move_to_gpu(i):
                context_name = i.owner.inputs[0].type.context_name
                replace = True
                break
        if not replace:
            # We replace if *all* clients are on the GPU
            clients = [c for o in node.outputs for c in o.clients]
            # Start optimistic (requires at least one client), then veto
            # on any client that is not a GpuFromHost transfer.
            replace = len(clients) != 0
            for c, idx in clients:
                if c == "output" or not isinstance(c.op, GpuFromHost):
                    replace = False
            # TODO: check that the clients want the same context?
            if replace:
                # All clients are GpuFromHost and we have at least one
                context_name = clients[0][0].op.context_name
        # Check if we should replace: bail out when no trigger fired,
        # when a cuda-only rewrite targets a non-cuda context, or when
        # any input is complex (unsupported on the GPU backend).
        if (
            not replace
            or (cuda_only and get_context(context_name).kind != b"cuda")
            or any(["complex" in getattr(i, "dtype", "") for i in node.inputs])
        ):
            return False
        # tag the inputs with the context in case
        # the context was derived from the outputs
        for i in node.inputs:
            i.tag.context_name = context_name
        # `maker` builds the GPU replacement; it may return an Op, a
        # list/tuple of output variables, or a single variable.
        new_op = maker(node.op, context_name, node.inputs, node.outputs)
        # This is needed as sometimes new_op inherits from OP.
        if new_op and new_op != node.op:
            if isinstance(new_op, Op):
                new_outputs = new_op(*node.inputs, return_list=True)
                to_cpu_fn = safe_to_cpu
            elif isinstance(new_op, (tuple, list)):
                new_outputs = new_op
                to_cpu_fn = safe_to_cpu
            else:  # suppose it is a variable on the GPU
                new_outputs = [new_op]

                def to_cpu_fn(x):
                    return x.transfer("cpu")

            # copy stack traces onto gpu outputs
            # also copy the stack traces onto HostFromGpu outputs
            on_cpu = []
            for old_output, new_output in zip(node.outputs, new_outputs):
                copy_stack_trace(old_output, new_output)
                cpu = to_cpu_fn(new_output)
                on_cpu.append(cpu)
                copy_stack_trace(old_output, cpu)
            return on_cpu
    return False
def local_conv2d_gradweight_cpu(node):
    # Lower AbstractConv2d_gradWeights to the legacy ConvOp, expressing
    # the weight gradient as a "valid" convolution over rearranged
    # image/topgrad tensors.  Returns None/[res] per the local-rewrite
    # convention.
    if (not isinstance(node.op, AbstractConv2d_gradWeights)
            or node.inputs[0].dtype == "float16"):
        return None
    img, topgrad, shape = node.inputs
    if not isinstance(img.type, TensorType) or not isinstance(
            topgrad.type, TensorType):
        return None
    if node.op.border_mode not in ["full", "valid"]:
        return None
    if not node.op.filter_flip:
        # Not tested yet
        return
    if node.op.num_groups > 1 or node.op.unshared:
        return None
    if node.op.border_mode == "valid" and (node.op.subsample != (1, 1)):
        return None

    dx, dy = node.op.subsample
    if dx not in (1, 2) or dy not in (1, 2):
        # Not implemented in the gradient of ConvOp
        return None

    # Fall back to all-unknown shapes when the op carries none.
    if node.op.imshp is None:
        op_imshp = (None, None, None, None)
    else:
        op_imshp = node.op.imshp
    if node.op.kshp is None:
        op_kshp = (None, None, None, None)
    else:
        op_kshp = node.op.kshp
    if None in op_imshp or None in op_kshp:
        if (dx, dy) != (1, 1):
            # We cannot infer the shapes
            return None

    # Determine gradient on kernels
    assert len(op_imshp) == 4 and len(op_kshp) == 4

    outshp = get_conv_output_shape(
        op_imshp,
        op_kshp,
        node.op.border_mode,
        node.op.subsample,
        node.op.filter_dilation,
    )[2:]
    # Output shape without subsampling, used as the "logical" shape below.
    fulloutshp = get_conv_output_shape(op_imshp, op_kshp,
                                       node.op.border_mode, (1, 1))[2:]

    # Swap batch and channel axes: the weight gradient treats channels
    # as the batch dimension.
    newimg = img.dimshuffle((1, 0, 2, 3))
    newtopgrad = topgrad.dimshuffle((1, 0, 2, 3))

    if node.op.border_mode == "valid":
        (img, filters) = (newimg, newtopgrad)
        kshp_logical = fulloutshp
        kshp_logical_top_aligned = False
        imshp_logical = None
        (bsize, nkern) = (op_imshp[1], op_kshp[0])
        imshp = (op_imshp[0], op_imshp[2], op_imshp[3])
        kshp = outshp
    elif node.op.border_mode == "full":
        # For "full" mode the roles of image and topgrad are exchanged.
        (img, filters) = (newtopgrad, newimg)
        kshp_logical = None
        kshp_logical_top_aligned = True
        imshp_logical = (op_imshp[0], fulloutshp[0], fulloutshp[1])
        (bsize, nkern) = (op_kshp[0], op_imshp[1])
        imshp = (op_imshp[0], outshp[0], outshp[1])
        kshp = op_imshp[2:]
    else:
        raise NotImplementedError(
            "Only [full,valid] modes are currently supported.")

    # Flip the kernels
    filters = filters[:, :, ::-1, ::-1]

    dw = ConvOp(
        imshp,
        kshp,
        nkern,
        bsize,
        1,
        1,
        output_mode="valid",
        unroll_batch=None,
        unroll_kern=None,
        unroll_patch=None,
        imshp_logical=imshp_logical,
        kshp_logical=kshp_logical,
        kshp_logical_top_aligned=kshp_logical_top_aligned,
        direction_hint="bprop weights",
    )
    res = dw(img, filters)
    copy_stack_trace(node.outputs[0], res)

    if node.op.border_mode == "valid":
        # Undo the batch/channel swap and the kernel flip.
        res = res.dimshuffle((1, 0, 2, 3))
        res = res[:, :, ::-1, ::-1]

    res = aesara.tensor.patternbroadcast(res, node.outputs[0].broadcastable)
    copy_stack_trace(node.outputs[0], res)
    return [res]
def local_conv2d_gradinputs_cpu(node):
    # Lower AbstractConv2d_gradInputs to the legacy ConvOp, expressing
    # the input gradient as a convolution of the output gradient with
    # transposed, flipped kernels.
    if (not isinstance(node.op, AbstractConv2d_gradInputs)
            or node.inputs[0].dtype == "float16"):
        return None
    kern, topgrad, shape = node.inputs
    if not isinstance(kern.type, TensorType) or not isinstance(
            topgrad.type, TensorType):
        return None
    if node.op.border_mode not in ["full", "valid"]:
        return None
    if not node.op.filter_flip:
        # Not tested yet
        return None
    if node.op.num_groups > 1 or node.op.unshared:
        return None

    # Conv 3d implementation, needed when subsample > 2
    if node.op.border_mode == "valid" and node.op.subsample != (1, 1):
        # The op don't support that anymore.
        return False

    # Conv2d Implementation
    dx, dy = node.op.subsample
    if dx not in (1, 2) or dy not in (1, 2):
        # Not implemented in the gradient of ConvOp
        return None

    # Fall back to all-unknown shapes when the op carries none.
    if node.op.imshp is None:
        op_imshp = (None, None, None, None)
    else:
        op_imshp = node.op.imshp
    if node.op.kshp is None:
        op_kshp = (None, None, None, None)
    else:
        op_kshp = node.op.kshp
    if None in op_imshp or None in op_kshp:
        if (dx, dy) != (1, 1):
            return None

    # The gradient w.r.t. the inputs of a "full" convolution is a
    # "valid" convolution and vice versa — hence the deliberate swap.
    mode = "valid"
    if not node.op.border_mode == "full":
        mode = "full"
    # Transpose in/out channel axes and flip the kernel spatially.
    filters = kern.dimshuffle((1, 0, 2, 3))
    filters = filters[:, :, ::-1, ::-1]

    outshp = get_conv_output_shape(
        op_imshp,
        op_kshp,
        node.op.border_mode,
        node.op.subsample,
        node.op.filter_dilation,
    )[2:]
    # Output shape without subsampling, used as the "logical" shape below.
    fulloutshp = get_conv_output_shape(op_imshp, op_kshp,
                                       node.op.border_mode, (1, 1))[2:]

    nkern = op_imshp[1]
    imshp = (op_kshp[0], outshp[0], outshp[1])
    imshp_logical = (op_kshp[0], fulloutshp[0], fulloutshp[1])
    din = ConvOp(
        imshp,
        op_kshp[2:],
        nkern,
        op_imshp[0],
        1,
        1,
        output_mode=mode,
        unroll_batch=None,
        unroll_kern=None,
        unroll_patch=None,
        imshp_logical=imshp_logical,
        kshp_logical=None,
        version=-1,
        direction_hint="bprop inputs",
    )
    din = din(topgrad, filters)
    copy_stack_trace(node.outputs[0], din)
    din = aesara.tensor.patternbroadcast(din, node.outputs[0].broadcastable)
    copy_stack_trace(node.outputs[0], din)
    return [din]