Example #1
0
 def transposed(*args):
     """Run a backward pass through `jaxpr` to obtain primal-input cotangents.

     The first `num_res` entries of `args` are residual values (known primals);
     the remaining entries are the cotangents of the jaxpr outputs. Every
     non-residual input is represented by an `ad.UndefinedPrimal` built from
     the matching entry of `primal_avals`, so `ad.backward_pass` produces
     cotangents for those inputs.

     Returns:
       One cotangent per entry of `primal_avals`, each passed through
       `ad.instantiate_zeros_aval` (presumably materializing symbolic zeros —
       confirm against the JAX version in use).
     """
     residuals, out_cts = split_list(args, [num_res])
     undefined_primals = [ad.UndefinedPrimal(aval) for aval in primal_avals]
     all_in_cts = ad.backward_pass(
         jaxpr.jaxpr, reduce_axes, False, jaxpr.consts,
         residuals + undefined_primals, out_cts)
     # Cotangents w.r.t. the residual inputs are not needed; keep only the rest.
     primal_cts = split_list(all_in_cts, [num_res])[1]
     return _map(ad.instantiate_zeros_aval, primal_avals, primal_cts)
Example #2
0
 def transposed(*args):
   """Transpose `jaxpr` by partially evaluating it and running a backward pass.

   `args` is the flat form of a (primal inputs, output cotangents) pair with
   tree structure `treedef`. Undefined primal inputs become unknown partial
   values and all other inputs become known ones; `pe.trace_to_jaxpr` then
   traces `core.eval_jaxpr` on them, and the resulting jaxpr is fed to
   `ad.backward_pass` with undefined placeholders for each of its inputs.

   Side effect: the tree structure of the computed cotangents is stored in
   `cell.treedef`.

   Returns:
     The flattened list of input cotangents.
   """
   primals_in, cts_out = tree_unflatten(treedef, args)

   def to_pval(x):
     # Undefined primals are the linear inputs; everything else is known.
     if ad.is_undefined_primal(x):
       return pe.PartialVal.unknown(x.aval)
     return pe.PartialVal.known(x)

   pvals_in = [to_pval(x) for x in primals_in]
   eval_fun = lu.wrap_init(partial(core.eval_jaxpr, jaxpr, ()))
   linear_jaxpr, _, linear_consts = pe.trace_to_jaxpr(eval_fun, pvals_in, False)
   placeholders = [ad.UndefinedPrimal(var.aval) for var in linear_jaxpr.invars]
   cts_in_tree = ad.backward_pass(linear_jaxpr, reduce_axes, False,
                                  linear_consts, placeholders, cts_out)
   flat_cts_in, out_treedef = tree_flatten(cts_in_tree)
   cell.treedef = out_treedef
   return flat_cts_in
Example #3
0
 def read_primal(v):
   """Return the primal value associated with the jaxpr atom `v`.

   Non-literal variables are looked up in `primal_env`; when absent, an
   `ad.UndefinedPrimal` carrying `v.aval` is returned instead. Literals
   return their embedded constant `v.val`.
   """
   if type(v) is not Literal:
     return primal_env.get(v, ad.UndefinedPrimal(v.aval))
   return v.val
Example #4
0
def _crown_linear_op(lin_primitive, out_bound, *invals, **kwargs):
    """Backward propagation of LinearBounds through the primitive `lin_primitive`.

    This is achieved by piggybacking on the auto-differentiation code and
    relying on the fact that the backward propagation through a linear layer is
    the transpose of the linear operation, which is the same as the operation
    done for gradient backpropagation.

    Args:
      lin_primitive: Jax primitive representing a bilinear operation, to go
        backward through.
      out_bound: CrownBackwardBound, linear function of the network outputs with
        respect to the output activation of this layer.
      *invals: inputs of the bound propagation in the forward pass. Each must be
        either an `ibp.IntervalBound` or a `jnp.ndarray` (treated as constant).
      **kwargs: Dict with the parameters of the linear operation.
    Returns:
      new_in_args: List built from the cotangents returned by the transpose
        rule: a CrownBackwardBound for each non-None cotangent, and None where
        the transpose rule returned None (expected for constant inputs).
    Raises:
      ValueError: if an input is neither an IntervalBound nor a jnp.ndarray.
    """
    backward_primitive = ad.get_primitive_transpose(lin_primitive)

    # Stack the lower and upper coefficients along axis 1 so that a single
    # backward pass handles both; they are split apart again at the end.
    cts_in = jnp.concatenate(
        [out_bound.lower_lin.lin_coeffs, out_bound.upper_lin.lin_coeffs],
        axis=1)
    # Trailing dimensions of the coefficients beyond those of the offset.
    nb_coeff_dim = (out_bound.lower_lin.lin_coeffs.ndim -
                    out_bound.lower_lin.offset.ndim)
    # One vmap is applied per offset dimension beyond the first.
    nb_output_dim = out_bound.lower_lin.offset.ndim - 1
    # Leading dimensions of cts_in, i.e. everything except the trailing
    # nb_coeff_dim ones. An explicit end index is used because
    # `shape[:-nb_coeff_dim]` would wrongly be empty when nb_coeff_dim == 0.
    leading_shape = cts_in.shape[:cts_in.ndim - nb_coeff_dim]

    unwrapped_invals = []
    for inval in invals:
        if isinstance(inval, ibp.IntervalBound):
            # Create a fake input that would have matched the artificial
            # construct we defined as cts_in.
            shape = leading_shape + inval.shape[1:]
            unwrapped_invals.append(ad.UndefinedPrimal(jnp.zeros(shape)))
        elif isinstance(inval, jnp.ndarray):
            unwrapped_invals.append(inval)
        else:
            raise ValueError('Unexpected input for the crown-ibp '
                             f'primitive for {lin_primitive}.')

    # Bounded inputs (and cts_in) are mapped along axis 1; constant inputs are
    # not mapped (axis None).
    vmap_outaxes = tuple(1 if isinstance(arg, ibp.IntervalBound) else None
                         for arg in invals)
    vmap_inaxes = (1,) + vmap_outaxes
    vmap_backward_op = functools.partial(backward_primitive, **kwargs)
    for _ in range(nb_output_dim):
        # Vmap over all the dimensions that we need to pass through.
        vmap_backward_op = jax.vmap(vmap_backward_op,
                                    in_axes=vmap_inaxes,
                                    out_axes=vmap_outaxes)
    cts_out = vmap_backward_op(cts_in, *unwrapped_invals)

    new_in_args = []
    for arg in cts_out:
        if arg is None:
            # This corresponds to an input that was a constant; we don't want
            # to propagate anything there.
            new_in_args.append(None)
        else:
            # Undo the concatenation performed on the coefficients above.
            lower_lin_coeffs, upper_lin_coeffs = jnp.split(arg, 2, axis=1)
            new_in_args.append(
                CrownBackwardBound(
                    LinearExpression(lower_lin_coeffs,
                                     out_bound.lower_lin.offset),
                    LinearExpression(upper_lin_coeffs,
                                     out_bound.upper_lin.offset)))
    return new_in_args
Example #5
0
 def read_primal(v):
   """Return the primal value associated with the jaxpr atom `v`.

   Literals are constants embedded in the jaxpr, so their value `v.val` is
   always known and is returned directly. Other variables are looked up in
   `primal_env`, falling back to an `ad.UndefinedPrimal` carrying `v.aval`
   when no primal value has been recorded.
   """
   if type(v) is core.Literal:
     # Inline constants (e.g. the `2.0` in `mul x 2.0`) never appear in
     # primal_env, so read their value off the literal itself.
     return v.val
   return primal_env.get(v, ad.UndefinedPrimal(v.aval))