def transposed(*args):
  # Split the flat arguments into residuals and output cotangents.
  res, cts_out = split_list(args, [num_res])
  # Residuals are known values; the primal inputs are unknown, so mark
  # them as UndefinedPrimal to request their cotangents.
  primals = res + [ad.UndefinedPrimal(aval) for aval in primal_avals]
  cts_in = ad.backward_pass(jaxpr.jaxpr, reduce_axes, False, jaxpr.consts,
                            primals, cts_out)
  # Drop the (unused) residual cotangents and instantiate symbolic zeros.
  _, cts_in = split_list(cts_in, [num_res])
  return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)
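# A minimal sketch of what backward_pass accomplishes, phrased with the
# public jax.vjp API instead of the internal machinery: given primal
# inputs and output cotangents, pull the cotangents back to the inputs.
# The function `f` and its argument values are made-up examples.
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * 2.0
y, f_vjp = jax.vjp(f, jnp.float32(1.0))
x_ct, = f_vjp(jnp.float32(1.0))  # 2 * cos(1.0), the input cotangent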
def transposed(*args):
  # Recover the primal inputs and output cotangents from the flat args.
  in_primals, out_cts = tree_unflatten(treedef, args)
  # Undefined primals become unknown partial values so that partial
  # evaluation stages out only the (linear) tangent computation.
  in_pvals = [pe.PartialVal.unknown(x.aval) if ad.is_undefined_primal(x)
              else pe.PartialVal.known(x) for x in in_primals]
  primal_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
  tangent_jaxpr, _, consts = pe.trace_to_jaxpr(primal_fun, in_pvals, False)
  dummy_args = [ad.UndefinedPrimal(v.aval) for v in tangent_jaxpr.invars]
  in_cts_ = ad.backward_pass(tangent_jaxpr, reduce_axes, False, consts,
                             dummy_args, out_cts)
  # Flatten the cotangents and stash their tree structure for the caller.
  in_cts, cell.treedef = tree_flatten(in_cts_)
  return in_cts
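# A minimal sketch of the eval_jaxpr round trip that primal_fun wraps,
# assuming a JAX version where jax.core.eval_jaxpr is importable: stage a
# function to a jaxpr with jax.make_jaxpr, then re-evaluate it on fresh
# arguments. The lambda is a made-up example.
import jax
from jax import core
import jax.numpy as jnp

closed = jax.make_jaxpr(lambda x: x * 2.0 + 1.0)(jnp.float32(0.0))
out, = core.eval_jaxpr(closed.jaxpr, closed.consts, jnp.float32(3.0))  # 7.0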
def read_primal(v):
  # Literals carry their value inline; anything else is looked up in the
  # primal environment, defaulting to an UndefinedPrimal marker.
  if type(v) is Literal:
    return v.val
  else:
    return primal_env.get(v, ad.UndefinedPrimal(v.aval))
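# A minimal sketch of why the Literal branch matters: constants appear
# inline in a jaxpr's equations as Literal atoms rather than as bound
# variables, so they never get an entry in primal_env. The lambda is a
# made-up example.
import jax
from jax import core

jaxpr = jax.make_jaxpr(lambda x: x + 1.0)(0.0)
lit = jaxpr.jaxpr.eqns[0].invars[1]
assert isinstance(lit, core.Literal) and lit.val == 1.0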
def _crown_linear_op(lin_primitive, out_bound, *invals, **kwargs):
  """Backward propagation of LinearBounds through `lin_primitive`.

  This piggybacks on the auto-differentiation code, relying on the fact
  that backward propagation through a linear layer is the transpose of
  the linear operation, which is the same operation performed by gradient
  backpropagation.

  Args:
    lin_primitive: Jax primitive representing a bilinear operation, to go
      backward through.
    out_bound: CrownBackwardBound, linear function of the network outputs
      with respect to the output activation of this layer.
    *invals: Inputs of the bound propagation in the forward pass.
    **kwargs: Dict with the parameters of the linear operation.
  Returns:
    new_in_args: List of CrownBackwardBound.
  """
  backward_primitive = ad.get_primitive_transpose(lin_primitive)
  # Stack the lower and upper linear coefficients so that a single
  # transpose pass propagates both.
  to_backprop = jnp.concatenate(
      [out_bound.lower_lin.lin_coeffs, out_bound.upper_lin.lin_coeffs],
      axis=1)
  nb_coeff_dim = (out_bound.lower_lin.lin_coeffs.ndim
                  - out_bound.lower_lin.offset.ndim)
  nb_output_dim = out_bound.lower_lin.offset.ndim - 1
  cts_in = to_backprop
  unwrapped_invals = []
  for inval in invals:
    if isinstance(inval, ibp.IntervalBound):
      # Create a fake input that would have matched the artificial
      # construct we defined as cts_in.
      shape = cts_in.shape[:-nb_coeff_dim] + inval.shape[1:]
      unwrapped_invals.append(ad.UndefinedPrimal(jnp.zeros(shape)))
    elif isinstance(inval, jnp.ndarray):
      unwrapped_invals.append(inval)
    else:
      raise ValueError('Unexpected input for the crown-ibp '
                       f'primitive for {lin_primitive}.')
  vmap_outaxes = tuple(1 if isinstance(arg, ibp.IntervalBound) else None
                       for arg in invals)
  vmap_inaxes = (1,) + vmap_outaxes
  backward_op = functools.partial(backward_primitive, **kwargs)
  vmap_backward_op = backward_op
  for _ in range(nb_output_dim):
    # Vmap over all the dimensions that we need to pass through.
    vmap_backward_op = jax.vmap(vmap_backward_op, in_axes=vmap_inaxes,
                                out_axes=vmap_outaxes)
  cts_out = vmap_backward_op(cts_in, *unwrapped_invals)

  new_in_args = []
  for arg in cts_out:
    if arg is None:
      # This corresponds to an input that was a constant; we don't want
      # to propagate anything there.
      new_in_args.append(arg)
    else:
      lower_lin_coeffs, upper_lin_coeffs = jnp.split(arg, 2, axis=1)
      new_in_args.append(
          CrownBackwardBound(
              LinearExpression(lower_lin_coeffs, out_bound.lower_lin.offset),
              LinearExpression(upper_lin_coeffs, out_bound.upper_lin.offset)))
  return new_in_args
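# A minimal sketch of the identity this function piggybacks on: for a
# linear map, the AD transpose applied to a cotangent performs the same
# operation as gradient backpropagation. Shown here with the public
# jax.linear_transpose; W and the cotangent are made-up example values.
import jax
import jax.numpy as jnp

W = jnp.arange(6.0).reshape(2, 3)
f = lambda x: W @ x
f_t = jax.linear_transpose(f, jnp.zeros(3))
x_ct, = f_t(jnp.ones(2))
# x_ct == W.T @ ones(2): the backward pass through the linear layer.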
def read_primal(v):
  if type(v) is core.Literal:
    # Literals carry their value inline, matching read_primal above.
    return v.val
  else:
    return primal_env.get(v, ad.UndefinedPrimal(v.aval))