Exemplo n.º 1
0
    def propagate(
        self,
        graph: bound_propagation.PropagationGraph,
        bounds: Nest[graph_traversal.GraphInput],
    ) -> Tuple[Nest[Bound], Dict[jax.core.Var, Union[Bound, Tensor]]]:
        """Propagates bounds through `graph` using the nonconvex transform.

        If a base bound-propagation transform is configured, it is run first.
        Its environment is then handed to the concretiser. The nonconvex
        transform is built from the concretiser, the graph and that base
        environment, and propagated forward over `bounds`. Any output bound
        that still requires concretizing is concretized before returning, so
        that its `lower` / `upper` are accessible to the caller.

        Args:
          graph: Graph to propagate bounds through.
          bounds: (Nested) inputs of the graph.
        Returns:
          output_bounds: (Nested) bounds on the graph outputs.
          env: Environment produced by the nonconvex propagation.
        """
        # Without a base boundprop this stays None, which is only acceptable
        # when the concretiser does not rely on base bounds.
        base_env = None
        if self._base_boundprop is not None:
            # Precompute the 'base' bounds for later use by the concretiser.
            base_algorithm = bound_propagation.ForwardPropagationAlgorithm(
                self._base_boundprop)
            _, base_env = base_algorithm.propagate(graph, bounds)

        transform = self._nonconvex_transform_ctor(
            self._concretizer, graph, base_env)
        nonconvex_algorithm = bound_propagation.ForwardPropagationAlgorithm(
            transform)
        output_bounds, env = nonconvex_algorithm.propagate(graph, bounds)

        # Make sure every returned bound exposes concrete `lower`/`upper`.
        for leaf_bound in jax.tree_util.tree_leaves(output_bounds):
            if not leaf_bound.requires_concretizing(None):
                continue
            leaf_bound.concretize(self._concretizer, graph, base_env)

        return output_bounds, env
Exemplo n.º 2
0
    def get_bounds(self, fun, input_bounds):
        """Bounds every output of `fun` with an LP (Planet) relaxation.

        The relaxation is built by propagating `input_bounds` through `fun`
        with a RelaxationTransform over IBP. Each output coordinate is then
        minimised and maximised with the CVXPY solver, at batch index 0 and
        zero objective bias.

        Args:
          fun: Function to bound.
          input_bounds: Bound object on the inputs of `fun`; its `lower` is
            used to evaluate `fun` once to find the number of outputs.
        Returns:
          A pair `(lower, upper)` of 1-D arrays, one entry per output element.
        """
        reference_output = fun(input_bounds.lower)
        num_outputs = reference_output.size

        transform = relaxation.RelaxationTransform(jax_verify.ibp_transform)
        var, env = bound_propagation.bound_propagation(
            bound_propagation.ForwardPropagationAlgorithm(transform),
            fun, input_bounds)

        solver = cvxpy_relaxation_solver.CvxpySolver
        bias = 0.
        batch_index = 0

        def minimize(coefficients):
            value, _, _ = relaxation.solve_relaxation(
                solver, coefficients, bias, var, env, batch_index)
            return value

        positions = jnp.arange(num_outputs)
        lowers, uppers = [], []
        for target in range(num_outputs):
            one_hot = (positions == target).astype(jnp.float32)
            lowers.append(minimize(one_hot))
            # The solver only minimises, so max(c.x) = -min(-c.x).
            uppers.append(-minimize(-one_hot))

        return jnp.array(lowers), jnp.array(uppers)
Exemplo n.º 3
0
    def solve_with_jax_verify(self):
        """Returns the LP-relaxation maximum of the linear objective.

        The input box `[inputs - eps, inputs + eps]` is clipped to `[0, 1]`.
        It is propagated through the model built from `self.network_params`
        with a RelaxationTransform over IBP. The resulting relaxation is then
        solved with CVXPY at batch index 0 with no time limit.

        Returns:
          The optimal (maximal) value of
          `self.objective . logits + self.objective_bias` over the relaxation.
        """
        def clip_to_unit(values):
            return jnp.minimum(jnp.maximum(values, 0.0), 1.0)

        init_bound = jax_verify.IntervalBound(
            clip_to_unit(self.inputs - self.eps),
            clip_to_unit(self.inputs + self.eps))

        logits_fn = make_model_fn(self.network_params)

        planet_algorithm = bound_propagation.ForwardPropagationAlgorithm(
            relaxation.RelaxationTransform(jax_verify.ibp_transform))
        var, env = bound_propagation.bound_propagation(
            planet_algorithm, logits_fn, init_bound)

        # The solver minimises, so the maximum is recovered as -min(-f),
        # negating both the coefficients and the bias.
        neg_value_opt, _, _ = relaxation.solve_relaxation(
            cvxpy_relaxation_solver.CvxpySolver,
            -self.objective,
            -self.objective_bias,
            var,
            env,
            index=0,
            time_limit_millis=None)
        return -neg_value_opt
Exemplo n.º 4
0
def interval_bound_propagation(function, *bounds):
    """Runs interval bound propagation (IBP), https://arxiv.org/abs/1810.12715.

    Args:
      function: Function whose output should be bounded. It takes the network
        inputs as its arguments.
      *bounds: jax_verify.IntervalBound instances bounding the function inputs.
    Returns:
      output_bound: Bounds on the output of the function obtained by IBP.
    """
    ibp_algorithm = bound_propagation.ForwardPropagationAlgorithm(
        bound_transform)
    output_bound, _env = bound_propagation.bound_propagation(
        ibp_algorithm, function, *bounds)
    return output_bound
Exemplo n.º 5
0
  def propagate(self,
                graph: bound_propagation.PropagationGraph,
                inputs: Nest[GraphInput]):
    """Propagates bounds through `graph` with jointly optimised relaxations.

    A GraphInspector pass records the graph nodes. A _RelaxationScanner is run
    backwards from two kinds of node to collect the relaxations each depends
    on: every graph-node argument of a primitive in
    CONCRETIZE_ARGS_PRIMITIVE, and every graph output. The collected
    relaxations are then optimised together by
    `self.jointly_optimize_relaxations`.

    Args:
      graph: Graph to propagate bounds through.
      inputs: (Nested) graph inputs. Leaves that are
        `bound_propagation.Bound` instances are treated as bounded inputs.
    Returns:
      outvals: List of bounds for the graph outputs, read from the optimised
        environment.
      env_with_final_bounds: Environment returned by
        `jointly_optimize_relaxations`.
    """
    # Inspect the graph to figure out which nodes need concretization.
    graph_inspector = bound_utils.GraphInspector()
    inspector_algorithm = bound_propagation.ForwardPropagationAlgorithm(
        graph_inspector)
    gn_outvals, env = inspector_algorithm.propagate(graph, inputs)

    flat_inputs, _ = jax.tree_util.tree_flatten(inputs)
    flat_bounds = [inp for inp in flat_inputs
                   if isinstance(inp, bound_propagation.Bound)]
    input_nodes_indices = [(i,) for i in range(len(flat_bounds))]

    # Maps a node index to the relaxations collected by scanning backwards
    # from that node. Each node is scanned at most once.
    relaxations = {}
    def collect_relaxations(graph_node):
      if graph_node.index not in relaxations:
        index_to_concretize = graph_node.index
        jaxpr_node = graph.jaxpr_node(index_to_concretize)
        scanner = _RelaxationScanner(self._relaxer)
        graph.backward_propagation(
            scanner, env, {jaxpr_node: graph_node},
            input_nodes_indices)
        relaxations[index_to_concretize] = scanner.node_relaxations

    for node in graph_inspector.nodes.values():
      node_primitive = node.primitive
      if node_primitive and node_primitive in CONCRETIZE_ARGS_PRIMITIVE:
        for node_arg in node.args:
          # Arguments can be literals/constants rather than graph nodes; those
          # have no index and need no relaxation.
          if isinstance(node_arg, bound_utils.GraphNode):
            collect_relaxations(node_arg)

    # Iterate over the outputs. Note their jaxpr nodes so they can define the
    # objective, and collect the relaxations needed to bound them.
    objective_nodes = []
    for gn in gn_outvals:
      collect_relaxations(gn)
      jaxpr_node = graph.jaxpr_node(gn.index)
      objective_nodes.append(jaxpr_node)

    env_with_final_bounds = self.jointly_optimize_relaxations(
        relaxations, graph, inputs, env, objective_nodes)

    outvals = [env_with_final_bounds[jaxpr_node_opted]
               for jaxpr_node_opted in objective_nodes]

    return outvals, env_with_final_bounds
Exemplo n.º 6
0
def solve_planet_relaxation(logits_fn,
                            initial_bounds,
                            boundprop_transform,
                            objective,
                            objective_bias,
                            index,
                            solver=cvxpy_relaxation_solver.CvxpySolver):
    """Solves the "Planet" (Ehlers 17) / "triangle" LP relaxation.

    jax_verify is used to generate the relaxation constraints, which are then
    handed to a generic solver. With CVXPY, defining the LP carries a large
    overhead. All constraints are defined element-wise so that convolutional
    layers are not materialised as one big (inefficient) matrix
    multiplication, and defining many constraints in CVXPY is slow.

    Args:
      logits_fn: Mapping from inputs (batch_size x input_size) to logits
        (batch_size, num_classes).
      initial_bounds: `IntervalBound` on the inputs, with lower and upper bounds
        of dimension (batch_size x input_size).
      boundprop_transform: bound_propagation.BoundTransform instance (for
        example `jax_verify.ibp_transform`). It pre-computes interval bounds
        on the intermediate activations used to define the Planet relaxation.
      objective: Coefficients applied to the output of logits_fn, defining the
        objective to minimize.
      objective_bias: Bias added to the objective.
      index: Batch index for which the relaxation is solved.
      solver: relaxation.RelaxationSolver selecting the LP backend.
    Returns:
      val: Optimal value of the relaxation.
      solution: Optimal solution found by the solver.
      status: Status reported by the relaxation solver.
    """
    planet_algorithm = bound_propagation.ForwardPropagationAlgorithm(
        relaxation.RelaxationTransform(boundprop_transform))
    variable, env = bound_propagation.bound_propagation(
        planet_algorithm, logits_fn, initial_bounds)
    value, solution, status = relaxation.solve_relaxation(
        solver, objective, objective_bias, variable, env,
        index=index, time_limit_millis=None)
    return value, solution, status
Exemplo n.º 7
0
def ibpforwardfastlin_bound_propagation(function, *bounds):
    """Obtains the best of IBP and ForwardFastlin bounds.

    The IBP transform and the forward FastLin transform are propagated
    together through `intersection.IntersectionBoundTransform`.

    Args:
      function: Function performing computation to obtain bounds for. Takes as
        arguments the network inputs.
      *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.
    Returns:
      output_bound: Bounds on the output of the function, obtained by
        intersecting the IBP and forward FastLin bounds.
    """
    combined_transform = intersection.IntersectionBoundTransform(
        ibp.bound_transform, forward_fastlin_transform)
    output_bound, _ = bound_propagation.bound_propagation(
        bound_propagation.ForwardPropagationAlgorithm(combined_transform),
        function, *bounds)
    return output_bound
Exemplo n.º 8
0
    def __init__(
        self,
        boundprop_transform: bound_propagation.BoundTransform,
        spec_fn: Callable[..., Tensor],
        *input_bounds: Bound,
    ):
        """Sets up a ReLU-based network SDP problem.

        `spec_fn` is propagated over `input_bounds` with an `_SdpTransform`
        wrapping `boundprop_transform`. The resulting output node and
        environment are stored on the instance.

        Args:
          boundprop_transform: Transform supplying concrete bounds.
          spec_fn: Network to verify.
          *input_bounds: Concrete bounds on the network inputs.
        """
        sdp_algorithm = bound_propagation.ForwardPropagationAlgorithm(
            _SdpTransform(boundprop_transform))
        output_node, env = bound_propagation.bound_propagation(
            sdp_algorithm, spec_fn, *input_bounds)
        self._output_node = output_node
        self._env = env
Exemplo n.º 9
0
def forward_fastlin_bound_propagation(function, *bounds):
    """Runs forward linear bound propagation with the FastLin ReLU relaxation.

    FastLin is described in https://arxiv.org/abs/1804.09699.

    Args:
      function: Function whose output should be bounded. It takes the network
        inputs as its arguments.
      *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.
    Returns:
      output_bound: Bounds on the output of the function obtained by FastLin.
    """
    fastlin_algorithm = bound_propagation.ForwardPropagationAlgorithm(
        forward_fastlin_transform)
    output_bound, _env = bound_propagation.bound_propagation(
        fastlin_algorithm, function, *bounds)
    return output_bound
Exemplo n.º 10
0
    def propagate(
        self, graph: bound_propagation.PropagationGraph,
        inputs: Nest[GraphInput]
    ) -> Tuple[Nest[LayerInput], Dict[jax.core.Var, LayerInput]]:
        """Propagates through `graph`, concretizing bounds on demand.

        A GraphInspector pass (using the concretizer's subgraph decider)
        records the graph nodes. The nodes are walked in order. For every node
        whose primitive the concretizer asks to concretize, each graph-node
        argument still holding a GraphNode in `env` is replaced by a concrete
        bound. Walking in order means earlier bounds are available when later
        ones are computed. The graph inputs are then written into `env`, and
        every output is concretized if it has not been already.

        Args:
          graph: Graph to propagate through.
          inputs: (Nested) graph inputs.
        Returns:
          outvals: List of concrete bounds, one per graph output.
          env: Environment mapping jaxpr variables to their bounds.
        """
        concretizer = self._backward_concretizer
        inspector = GraphInspector(concretizer.should_handle_as_subgraph)
        gn_outvals, env = bound_propagation.ForwardPropagationAlgorithm(
            inspector).propagate(graph, inputs)

        def ensure_concrete(jaxpr_node):
            # An entry that is still a GraphNode has not been concretized yet.
            current = env[jaxpr_node]
            if not isinstance(current, GraphNode):
                return current
            concrete = concretizer.concrete_bound(graph, inputs, env, jaxpr_node)
            env[jaxpr_node] = concrete
            return concrete

        for node in inspector.nodes.values():
            if not node.primitive:
                continue
            if not concretizer.concretize_args(node.primitive):
                continue
            for node_arg in node.args:
                if isinstance(node_arg, GraphNode):
                    ensure_concrete(graph.jaxpr_node(node_arg.index))

        # Fill in the bounds for the inputs.
        for in_jaxpr_node, in_bound in zip(graph.inputs, inputs):
            env[in_jaxpr_node] = in_bound

        outvals = [ensure_concrete(graph.jaxpr_node(gn.index))
                   for gn in gn_outvals]
        return outvals, env
def get_boundprop(
        name: str,
        elision: bool) -> Callable[..., forward_linear_bounds.LinearBound]:
    """Builds a forward linear bound-propagation function.

    Args:
      name: ReLU relaxer to use, either 'fastlin' or 'crown'.
      elision: Forwarded to `forward_linear_bounds.ForwardLinearBoundTransform`.
    Returns:
      A function `bound_prop(function, *bounds)` that propagates `bounds`
      through `function` and returns the output bound.
    Raises:
      ValueError: If `name` is not a supported relaxer.
    """
    if name == 'fastlin':
        relaxer = linear_bound_utils.fastlin_rvt_relaxer
    elif name == 'crown':
        relaxer = linear_bound_utils.crown_rvt_relaxer
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unknown boundprop name {name!r}; expected 'fastlin' or 'crown'.")

    transform = forward_linear_bounds.ForwardLinearBoundTransform(
        relaxer, elision)
    algorithm = bound_propagation.ForwardPropagationAlgorithm(transform)

    def bound_prop(function, *bounds) -> forward_linear_bounds.LinearBound:
        output_bound, _ = bound_propagation.bound_propagation(
            algorithm, function, *bounds)
        return output_bound

    return bound_prop
Exemplo n.º 12
0
def forward_crown_bound_propagation(function, *bounds):
    """Performs forward linear bound propagation.

    This uses the ReLU relaxation of CROWN
    (https://arxiv.org/abs/1811.00866).

    Args:
      function: Function performing computation to obtain bounds for. Takes as
        arguments the network inputs.
      *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.
    Returns:
      output_bound: Bounds on the output of the function obtained by forward
        CROWN.
    """
    forward_crown_transform = ForwardLinearBoundTransform(
        linear_bound_utils.crown_rvt_relaxer)
    crown_algorithm = bound_propagation.ForwardPropagationAlgorithm(
        forward_crown_transform)
    output_bound, _ = bound_propagation.bound_propagation(
        crown_algorithm, function, *bounds)
    return output_bound
Exemplo n.º 13
0
    def test_cvxpy_relaxation(self, model_cls):
        """LP lower bound on logit 0 of example 0 must not exceed its value.

        The model is evaluated at all-zero inputs. A box of radius 1 around
        those inputs is relaxed with RelaxationTransform over IBP, and the
        CVXPY minimum of the first logit is compared with the actual output.
        """
        @hk.transform_with_state
        def model_pred(inputs, is_training, test_local_stats=False):
            return model_cls()(inputs, is_training, test_local_stats)

        inps = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
        params, state = model_pred.init(
            jax.random.PRNGKey(42), inps, is_training=True)

        def logits_fun(inputs):
            applied = model_pred.apply(
                params, state, None, inputs, False, test_local_stats=False)
            return applied[0]

        output = logits_fun(inps)
        input_bounds = jax_verify.IntervalBound(inps - 1.0, inps + 1.0)

        planet_algorithm = bound_propagation.ForwardPropagationAlgorithm(
            relaxation.RelaxationTransform(jax_verify.ibp_transform))
        var, env = bound_propagation.bound_propagation(
            planet_algorithm, logits_fun, input_bounds)

        batch_index = 0
        first_logit_objective = jnp.zeros(output.shape[1:]).at[0].set(1)

        lower_bound, _, _ = relaxation.solve_relaxation(
            cvxpy_relaxation_solver.CvxpySolver, first_logit_objective, 0.,
            var, env, batch_index)

        self.assertLessEqual(lower_bound, output[batch_index, 0])
Exemplo n.º 14
0
 def __init__(self, forward_transform: bound_propagation.BoundTransform,
              backward_concretizer: BackwardConcretizer):
     """Wraps `forward_transform` in a forward algorithm; keeps the concretizer.

     Args:
       forward_transform: Transform used to build a
         `bound_propagation.ForwardPropagationAlgorithm`.
       backward_concretizer: Concretizer stored for later use.
     """
     forward_algorithm = bound_propagation.ForwardPropagationAlgorithm(
         forward_transform)
     self._forward_algorithm = forward_algorithm
     self._backward_concretizer = backward_concretizer
Exemplo n.º 15
0
  def concrete_bound_chunk(
      self,
      graph: bound_propagation.PropagationGraph,
      inputs: Nest[GraphInput],
      env: Dict[jax.core.Var, LayerInput],
      node_ref: jax.core.Var,
      obj: Tensor,
  ) -> Tensor:
    """Computes bounds for a chunk of objectives with optimised relaxations.

    Each objective in `obj` is handled separately, with `jax.vmap` over the
    leading axis, so each one gets independent relaxation parameters. For each
    objective, the parameters are optimised with `self._opt` for
    `self._num_opt_steps` steps, minimising the sum of the negated bound. The
    bound is then re-evaluated at the best parameters seen.

    Args:
      graph: Graph being bounded.
      inputs: (Nested) graph inputs.
      env: Environment of already-computed bounds, keyed by jaxpr variable.
      node_ref: Jaxpr variable of the node to bound.
      obj: Objectives; the leading axis is mapped over.
    Returns:
      The optimised bounds, stacked along the leading axis of `obj`.
    """
    # Analyse the relevant parts of the graph.
    flat_inputs, _ = jax.tree_util.tree_flatten(inputs)
    bound_inputs = [inp for inp in flat_inputs
                    if isinstance(inp, bound_propagation.Bound)]
    input_nodes_indices = [(i,) for i in range(len(bound_inputs))]
    scanner = _RelaxationScanner(self._relaxer)
    graph.backward_propagation(
        scanner, env, {node_ref: env[node_ref]}, input_nodes_indices)

    # Allow lookup of any node's input bounds, for parameter initialisation.
    graph_inspector = bound_utils.GraphInspector()
    bound_propagation.ForwardPropagationAlgorithm(
        graph_inspector).propagate(graph, inputs)

    def input_bounds(index: Index) -> Sequence[LayerInput]:
      graph_node = graph_inspector.nodes[index]
      # Graph-node arguments are looked up in `env`; other args are passed as is.
      return [env[graph.jaxpr_node(arg.index)]
              if isinstance(arg, bound_utils.GraphNode) else arg
              for arg in graph_node.args]

    # Define optimisation for a single neuron's bound. (We'll vmap afterwards.)
    # This ensures that each neuron uses independent relaxation parameters.
    def optimized_concrete_bound(one_obj):
      def concrete_bound(relax_params):
        return self._bind(
            scanner.node_relaxations, relax_params).concrete_bound_chunk(
                graph, inputs, env, node_ref, jnp.expand_dims(one_obj, 0))

      # Function to minimise: negated sum of the guaranteed bound.
      def objective(relax_params):
        lb_min = concrete_bound(relax_params)
        return jnp.sum(-lb_min)

      val_and_grad_fn = jax.value_and_grad(objective)

      # Optimise the relaxation parameters.
      initial_params = self._initial_params(scanner, input_bounds)
      initial_state = (initial_params, self._opt.init(initial_params),
                       initial_params, jnp.inf)
      def update_state(_, state):
        params, opt_state, best_params, best_val = state
        params_val, params_grad = val_and_grad_fn(params)
        # Compute the next step in the optimization process.
        updates, next_opt_state = self._opt.update(params_grad, opt_state)
        next_params = optax.apply_updates(params, updates)
        next_params = self._project_params(scanner, next_params)
        # Track the best params seen; `params_val` was evaluated at `params`.
        params_improved = params_val < best_val
        update_best_params = lambda p, best: jnp.where(params_improved, p, best)

        # `jax.tree_multimap` was removed from JAX; `tree_map` accepts
        # several trees with the same structure.
        next_best_params = jax.tree_util.tree_map(update_best_params,
                                                  params, best_params)
        next_best_val = jnp.minimum(best_val, params_val)
        return next_params, next_opt_state, next_best_params, next_best_val
      _, _, relax_params, _ = jax.lax.fori_loop(
          0, self._num_opt_steps, update_state, initial_state)

      # Evaluate the relaxation at these parameters.
      return concrete_bound(jax.lax.stop_gradient(relax_params))

    return jax.vmap(optimized_concrete_bound)(obj)