コード例 #1
0
    def test_chunking(self, relaxer):
        """Chunked and unchunked backward concretization must agree.

        A two-layer ReLU network is bounded twice with the same backward
        transform: once through a `ChunkedBackwardConcretizer` with
        `max_chunk_size=16` and once with `max_chunk_size=0`. The lower and
        upper bounds of both runs are then compared.
        """
        batch_size, input_size, hidden_size, final_size = 3, 2, 5, 4

        lower, upper = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                                (batch_size, input_size),
                                                minval=-1.,
                                                maxval=1.)
        input_bound = jax_verify.IntervalBound(lower, upper)

        w_hidden = jax.random.uniform(jax.random.PRNGKey(1),
                                      (input_size, hidden_size))
        w_final = jax.random.uniform(jax.random.PRNGKey(2),
                                     (hidden_size, final_size))

        def model_fun(inp):
            return jax.nn.relu(inp @ w_hidden) @ w_final

        # Parameterized relaxers need the optimizing transform; the others use
        # the plain linear backward transform.
        is_parameterized = isinstance(
            relaxer, linear_bound_utils.ParameterizedLinearBoundsRelaxer)
        if not is_parameterized:
            transform = backward_crown.LinearBoundBackwardTransform(
                relaxer, backward_crown.CONCRETIZE_ARGS_PRIMITIVE)
        else:
            transform = backward_crown.OptimizingLinearBoundBackwardTransform(
                relaxer,
                backward_crown.CONCRETIZE_ARGS_PRIMITIVE,
                optax.adam(1.e-3),
                num_opt_steps=10)

        def propagate(max_chunk_size):
            concretizer = backward_crown.ChunkedBackwardConcretizer(
                transform, max_chunk_size=max_chunk_size)
            algorithm = bound_utils.BackwardConcretizingAlgorithm(concretizer)
            result, _ = bound_propagation.bound_propagation(
                algorithm, model_fun, input_bound)
            return result

        chunked = propagate(16)
        unchunked = propagate(0)

        np.testing.assert_array_almost_equal(chunked.lower, unchunked.lower)
        np.testing.assert_array_almost_equal(chunked.upper, unchunked.upper)
コード例 #2
0
ファイル: lp_test.py プロジェクト: deepmind/jax_verify
    def solve_with_jax_verify(self):
        """Returns the maximum of the objective under the LP relaxation.

        The input box is `self.inputs -/+ self.eps`, clamped elementwise to
        [0, 1]. The relaxation is built on top of IBP intermediate bounds and
        solved with the CVXPY solver for batch element 0.
        """
        def clamp_to_unit_interval(values):
            return jnp.minimum(jnp.maximum(values, 0.0), 1.0)

        init_bound = jax_verify.IntervalBound(
            clamp_to_unit_interval(self.inputs - self.eps),
            clamp_to_unit_interval(self.inputs + self.eps))

        model_fn = make_model_fn(self.network_params)
        forward_algorithm = bound_propagation.ForwardPropagationAlgorithm(
            relaxation.RelaxationTransform(jax_verify.ibp_transform))
        relaxed_output, relaxation_env = bound_propagation.bound_propagation(
            forward_algorithm, model_fn, init_bound)

        # The solver minimizes, so the maximum is recovered as -min(-objective),
        # negating both the objective coefficients and its bias.
        minimized_negation, _, _ = relaxation.solve_relaxation(
            cvxpy_relaxation_solver.CvxpySolver,
            -self.objective,
            -self.objective_bias,
            relaxed_output,
            relaxation_env,
            index=0,
            time_limit_millis=None)
        return -minimized_negation
コード例 #3
0
    def get_bounds(self, fun, input_bounds):
        """Computes per-output LP-relaxation bounds of `fun` over `input_bounds`.

        IBP supplies the intermediate bounds for the relaxation. Each output is
        then bounded by two CVXPY solves: one minimizing the output, and one
        minimizing its negation.

        Args:
          fun: Function to bound; it is also evaluated at `input_bounds.lower`
            to determine the number of outputs.
          input_bounds: Bounds on the inputs of `fun`.
        Returns:
          (lower, upper): jnp arrays with one entry per element of
            `fun(input_bounds.lower)`, in flattened order.
        """
        output = fun(input_bounds.lower)

        relaxation_transform = relaxation.RelaxationTransform(
            jax_verify.ibp_transform)
        var, env = bound_propagation.bound_propagation(
            bound_propagation.ForwardPropagationAlgorithm(
                relaxation_transform), fun, input_bounds)

        objective_bias = 0.
        # NOTE(review): the objective length is `output.size` while only batch
        # element 0 is solved, which presumes a batch size of 1 — confirm.
        index = 0

        lower_bounds = []
        upper_bounds = []
        for output_idx in range(output.size):
            # One-hot objective selecting a single output.
            objective = (jnp.arange(output.size) == output_idx).astype(
                jnp.float32)

            lower_bound, _, _ = relaxation.solve_relaxation(
                cvxpy_relaxation_solver.CvxpySolver, objective, objective_bias,
                var, env, index)

            # The solver minimizes: max(c.x + b) = -min(-c.x - b), so both the
            # coefficients and the bias must be negated.
            neg_upper_bound, _, _ = relaxation.solve_relaxation(
                cvxpy_relaxation_solver.CvxpySolver, -objective,
                -objective_bias, var, env, index)
            lower_bounds.append(lower_bound)
            upper_bounds.append(-neg_upper_bound)

        return jnp.array(lower_bounds), jnp.array(upper_bounds)
コード例 #4
0
def build_nonconvex_formulation(bound_cls, concretizer_ctor, function,
                                *bounds):
    """Builds the optimizable non-convex objective for `function`.

    Args:
      bound_cls: Bound class to use. It determines which dual formulation is
        computed.
      concretizer_ctor: Constructor for the concretizer used to obtain
        intermediate bounds.
      function: Function performing the computation to bound. Takes the
        network inputs as its only arguments.
      *bounds: jax_verify.NonConvexBound, bounds on the inputs of `function`.
        These can be created using
        jax_verify.NonConvexBound.initial_nonconvex_bound.
    Returns:
      output_bound: NonConvex bound that can be optimized with a solver.
    """
    # Input bounds are built with `bound_cls.initial_nonconvex_bound`, with
    # the concretizer constructor bound as its first argument.
    initial_bound_fn = functools.partial(bound_cls.initial_nonconvex_bound,
                                         concretizer_ctor)

    # Each registered primitive transform receives `bound_cls` as its first
    # argument.
    per_primitive = {}
    for primitive, transform in _nonconvex_primitive_transform.items():
        per_primitive[primitive] = functools.partial(transform, bound_cls)

    opwise_transform = bound_propagation.OpwiseBoundTransform(
        initial_bound_fn, per_primitive)
    output_bound, _ = bound_propagation.bound_propagation(
        opwise_transform, function, *bounds)
    return output_bound
コード例 #5
0
def crown_bound_propagation(function, *bounds):
    """Performs CROWN as described in https://arxiv.org/abs/1811.00866.

    Args:
      function: Function performing computation to obtain bounds for. Takes as
        only argument the network inputs.
      *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.
    Returns:
      output_bound: Bounds on the output of the function obtained by CROWN.
    """
    # `_crown_transform` (defined elsewhere in the module) drives propagation;
    # the second return value of bound_propagation is discarded.
    output_bound, _ = bound_propagation.bound_propagation(
        _crown_transform, function, *bounds)
    return output_bound
コード例 #6
0
ファイル: ibp.py プロジェクト: zeta1999/jax_verify
def interval_bound_propagation(function, *bounds):
    """Performs IBP as described in https://arxiv.org/abs/1810.12715.

    Args:
      function: Function performing computation to obtain bounds for. Takes as
        only argument the network inputs.
      *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.
    Returns:
      output_bound: Bounds on the output of the function obtained by IBP.
    """
    # `bound_transform` is the module-level IBP transform; the second return
    # value of bound_propagation is discarded.
    output_bound, _ = bound_propagation.bound_propagation(
        bound_transform, function, *bounds)
    return output_bound
コード例 #7
0
def ibpfastlin_bound_propagation(function, *bounds):
    """Obtains the best of IBP and FastLin bounds.

    The IBP and FastLin transforms are combined with an
    `intersection.IntersectionBoundTransform`, so the tighter of the two
    bounds is kept.

    Args:
      function: Function performing computation to obtain bounds for. Takes as
        only argument the network inputs.
      *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.
    Returns:
      output_bound: Bounds on the output of the function obtained by
        intersecting the IBP and FastLin bounds.
    """
    intersected_transform = intersection.IntersectionBoundTransform(
        ibp.bound_transform, fastlin_transform)
    output_bound, _ = bound_propagation.bound_propagation(
        intersected_transform, function, *bounds)
    return output_bound
コード例 #8
0
def solve_planet_relaxation(logits_fn,
                            initial_bounds,
                            boundprop_transform,
                            objective,
                            objective_bias,
                            index,
                            solver=cvxpy_relaxation_solver.CvxpySolver):
    """Solves the "Planet" (Ehlers 17) or "triangle" relaxation.

    jax_verify generates the relaxation constraints, and a generic solver then
    solves them. With CVXPY, defining the LP carries a large overhead.
    Constraints are defined element-wise, so that convolutional layers are
    not turned into a single (inefficient) matrix multiplication, and CVXPY is
    slow at defining large numbers of constraints.

    Args:
      logits_fn: Mapping from inputs (batch_size x input_size) -> (batch_size,
        num_classes).
      initial_bounds: `IntervalBound` with initial bounds on inputs, with lower
        and upper bounds of dimension (batch_size x input_size).
      boundprop_transform: bound_propagation.BoundTransform instance, such as
        `jax_verify.ibp_transform`. Used to pre-compute interval bounds for
        intermediate activations used in defining the Planet relaxation.
      objective: Objective to optimize, given as an array of coefficients to be
        applied to the output of logits_fn defining the objective to minimize.
      objective_bias: Bias to add to the objective.
      index: Index in the batch for which to solve the relaxation.
      solver: A relaxation.RelaxationSolver, which specifies the backend used to
        solve the resulting LP.
    Returns:
      val: The optimal value from the relaxation.
      solution: The optimal solution found by the solver.
      status: The status of the relaxation solver.
    """
    forward_algorithm = bound_propagation.ForwardPropagationAlgorithm(
        relaxation.RelaxationTransform(boundprop_transform))
    relaxed_output, env = bound_propagation.bound_propagation(
        forward_algorithm, logits_fn, initial_bounds)

    # `time_limit_millis=None` is passed, i.e. no explicit time limit is set.
    value, solution, status = relaxation.solve_relaxation(
        solver, objective, objective_bias, relaxed_output, env,
        index=index, time_limit_millis=None)
    return value, solution, status
コード例 #9
0
    def __init__(
        self,
        boundprop_transform: bound_propagation.BoundTransform,
        spec_fn: Callable[..., Tensor],
        *input_bounds: Bound,
    ):
        """Initialises a ReLU-based network SDP problem.

        `input_bounds` are propagated forward through `spec_fn` with an
        `_SdpTransform` wrapping `boundprop_transform`. The resulting output
        node and environment are stored on the instance.

        Args:
          boundprop_transform: Transform to supply concrete bounds.
          spec_fn: Network to verify.
          *input_bounds: Concrete bounds on the network inputs.
        """
        sdp_algorithm = bound_propagation.ForwardPropagationAlgorithm(
            _SdpTransform(boundprop_transform))
        output_node, env = bound_propagation.bound_propagation(
            sdp_algorithm, spec_fn, *input_bounds)
        self._output_node = output_node
        self._env = env
コード例 #10
0
def backward_fastlin_bound_propagation(
    function: Callable[..., Nest[Tensor]],
    *bounds: Nest[GraphInput]) -> Nest[LayerInput]:
  """Performs FastLin as described in https://arxiv.org/abs/1804.09699.

  Linear bounds are propagated backward and concretized with
  `backward_fastlin_concretizer`.

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only argument the network inputs.
    *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.
  Returns:
    output_bound: Bounds on the output of the function obtained by FastLin.
  """
  output_bound, _ = bound_propagation.bound_propagation(
      bound_utils.BackwardConcretizingAlgorithm(backward_fastlin_concretizer),
      function,
      *bounds)
  return output_bound
コード例 #11
0
def forward_fastlin_bound_propagation(function, *bounds):
    """Performs forward linear bound propagation with the FastLin relaxation.

    The ReLU relaxation is the one from FastLin
    (https://arxiv.org/abs/1804.09699), supplied by
    `forward_fastlin_transform`.

    Args:
      function: Function performing computation to obtain bounds for. Takes as
        only argument the network inputs.
      *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.
    Returns:
      output_bound: Bounds on the output of the function obtained by FastLin.
    """
    forward_algorithm = bound_propagation.ForwardPropagationAlgorithm(
        forward_fastlin_transform)
    output_bound, _ = bound_propagation.bound_propagation(
        forward_algorithm, function, *bounds)
    return output_bound
コード例 #12
0
    def test_equal_bounds_parameterized(self):
        """A zero-width input box must give a zero-width ReLU output bound.

        The input lower and upper bounds are the same point. The bound obtained
        with the optimizing parameterized backward transform must therefore
        have equal lower and upper values.
        """
        point = jnp.array([-1., 1.])
        degenerate_bound = jax_verify.IntervalBound(point, point)

        optimizing_transform = (
            backward_crown.OptimizingLinearBoundBackwardTransform(
                linear_bound_utils.parameterized_relaxer,
                backward_crown.CONCRETIZE_ARGS_PRIMITIVE,
                optax.adam(1.e-3),
                num_opt_steps=10))
        algorithm = bound_utils.BackwardConcretizingAlgorithm(
            backward_crown.ChunkedBackwardConcretizer(optimizing_transform))

        output_bound, _ = bound_propagation.bound_propagation(
            algorithm, jax.nn.relu, degenerate_bound)

        np.testing.assert_array_almost_equal(output_bound.lower,
                                             output_bound.upper)
コード例 #13
0
def forward_crown_bound_propagation(function, *bounds):
    """Performs forward linear bound propagation.

    This uses the ReLU relaxation of CROWN (https://arxiv.org/abs/1811.00866),
    via `linear_bound_utils.crown_rvt_relaxer`.

    Args:
      function: Function performing computation to obtain bounds for. Takes as
        only argument the network inputs.
      *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.
    Returns:
      output_bound: Bounds on the output of the function obtained by forward
        CROWN.
    """
    forward_crown_transform = ForwardLinearBoundTransform(
        linear_bound_utils.crown_rvt_relaxer)
    output_bound, _ = bound_propagation.bound_propagation(
        bound_propagation.ForwardPropagationAlgorithm(forward_crown_transform),
        function, *bounds)
    return output_bound
コード例 #14
0
def crownibp_bound_propagation(
    function: Callable[..., Nest[Tensor]],
    *bounds: Nest[GraphInput]) -> Nest[LayerInput]:
  """Performs Crown-IBP as described in https://arxiv.org/abs/1906.06316.

  IBP (`ibp.bound_transform`) is used to obtain the intermediate bounds.
  Linear bounds are then propagated backward and concretized with
  `backward_crown_concretizer`.

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only argument the network inputs.
    *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.
  Returns:
    output_bounds: Bounds on the outputs of the function obtained by Crown-IBP.
  """
  output_bounds, _ = bound_propagation.bound_propagation(
      bound_utils.BackwardAlgorithmForwardConcretization(
          ibp.bound_transform, backward_crown_concretizer),
      function,
      *bounds)
  return output_bounds
コード例 #15
0
def nonconvex_ibp_bound_propagation(
    function: Callable[..., Nest[Tensor]],
    *bounds: Nest[graph_traversal.GraphInput],
    graph_simplifier=synthetic_primitives.default_simplifier,
) -> Nest[nonconvex.NonConvexBound]:
    """Builds the non-convex objective, using IBP as the base bound propagation.

    The algorithm combines `duals.WolfeNonConvexBound` with a
    `nonconvex.BaseBoundConcretizer`, and uses `ibp.bound_transform` as its
    base bound propagation.

    Args:
      function: Function performing computation to obtain bounds for. Takes as
        only arguments the network inputs.
      *bounds: Bounds on the inputs of the function.
      graph_simplifier: What graph simplifier to use.
    Returns:
      output_bounds: NonConvex bounds that can be optimized with a solver.
    """
    ibp_nonconvex_algorithm = nonconvex.nonconvex_algorithm(
        duals.WolfeNonConvexBound,
        nonconvex.BaseBoundConcretizer(),
        base_boundprop=ibp.bound_transform,
    )
    output_bounds, _ = bound_propagation.bound_propagation(
        ibp_nonconvex_algorithm,
        function,
        *bounds,
        graph_simplifier=graph_simplifier,
    )
    return output_bounds
コード例 #16
0
def _nonconvex_boundprop(params,
                         x,
                         epsilon,
                         input_bounds,
                         nonconvex_boundprop_steps=100,
                         nonconvex_boundprop_nodes=128):
    """Runs non-convex bound propagation and returns per-layer interval bounds.

    Args:
      params: Network parameters, forwarded to `utils.predict_cnn`.
      x: Nominal input, forwarded to `utils.init_bound` (with
        `add_batch_dim=False`).
      epsilon: Perturbation size, forwarded to `utils.init_bound`.
      input_bounds: Forwarded to `utils.init_bound` as `input_bounds`.
      nonconvex_boundprop_steps: `num_steps` of the FISTA optimizer.
      nonconvex_boundprop_nodes: `max_parallel_nodes` of the optimizing
        concretizer.
    Returns:
      A list whose first entry is the initial bound from `utils.init_bound`.
      Each following entry is a `utils.IntBound` holding one intermediate
      pre-activation bound together with that bound clamped at zero (ReLU).
    """
    init_bounds = utils.init_bound(x, epsilon, input_bounds=input_bounds,
                                   add_batch_dim=False)

    # `include_preactivations=True` makes the network also return its
    # pre-activations; their bounds are the second output unpacked below.
    all_act_fun = functools.partial(utils.predict_cnn, params,
                                    include_preactivations=True)

    concretizer = optimizers.OptimizingConcretizer(
        FistaOptimizer(num_steps=nonconvex_boundprop_steps),
        max_parallel_nodes=nonconvex_boundprop_nodes)
    algorithm = nonconvex.nonconvex_algorithm(duals.WolfeNonConvexBound,
                                              concretizer)

    (_, preactivation_bounds), _ = bound_propagation.bound_propagation(
        algorithm, all_act_fun,
        jax_verify.IntervalBound(init_bounds.lb, init_bounds.ub))

    return [init_bounds] + [
        utils.IntBound(lb_pre=pre_bound.lower,
                       ub_pre=pre_bound.upper,
                       lb=jnp.maximum(pre_bound.lower, 0),
                       ub=jnp.maximum(pre_bound.upper, 0))
        for pre_bound in preactivation_bounds
    ]
コード例 #17
0
ファイル: crown_ibp.py プロジェクト: zeta1999/jax_verify
def crownibp_bound_propagation(function, bounds):
    """Performs Crown-IBP as described in https://arxiv.org/abs/1906.06316.

    IBP is run first to obtain intermediate bounds. Linear bounds are then
    propagated backward, starting from an identity on the output, and
    concretized over `bounds`.

    Args:
      function: Function performing computation to obtain bounds for. Takes as
        only argument the network inputs.
      bounds: jax_verify.IntervalBound, bounds on the inputs of the function.
    Returns:
      output_bound: Bounds on the output of the function obtained by Crown-IBP.
    """
    ibp_bound, graph = bound_propagation.bound_propagation(
        ibp.bound_transform, function, bounds)

    # Define the initial bound to propagate backward.
    assert hasattr(ibp_bound, 'shape'), (
        'crownibp_bound_propagation requires `function` to output a single array '
        'as opposed to an arbitrary pytree')
    batch_size = ibp_bound.shape[0]
    act_shape = ibp_bound.shape[1:]

    # `np.prod(())` is the float 1.0, which `jnp.eye` rejects as a size; the
    # int cast keeps per-example scalar outputs (act_shape == ()) working.
    nb_act = int(np.prod(act_shape))
    # Identity coefficients: each output activation starts as the linear
    # function selecting itself, replicated along the leading batch axis.
    identity_lin_coeffs = jnp.reshape(jnp.eye(nb_act), act_shape + act_shape)
    initial_lin_coeffs = jnp.repeat(jnp.expand_dims(identity_lin_coeffs, 0),
                                    batch_size,
                                    axis=0)
    initial_offsets = jnp.zeros_like(ibp_bound.lower)

    # The lower and upper linear expressions start out identical.
    initial_backward_bound = CrownBackwardBound(
        LinearExpression(initial_lin_coeffs, initial_offsets),
        LinearExpression(initial_lin_coeffs, initial_offsets))

    # NOTE(review): `sum` presumably combines backward contributions that reach
    # the same node — confirm against graph.backward_propagation. The trailing
    # comma unpacks exactly one input bound.
    input_fun, = graph.backward_propagation(_primitive_transform, sum,
                                            initial_backward_bound)

    return concretize_backward_bound(input_fun, bounds)
コード例 #18
0
ファイル: model_zoo_test.py プロジェクト: deepmind/jax_verify
    def test_cvxpy_relaxation(self, model_cls):
        """The CVXPY LP relaxation must lower-bound the first logit.

        `model_cls` is built with Haiku on a (4, 28, 28, 1) batch of zeros and
        bounded over the box zeros -/+ 1. The LP lower bound on logit 0 of
        batch element 0 must not exceed the model's actual logit there.
        """
        @hk.transform_with_state
        def model_pred(inputs, is_training, test_local_stats=False):
            return model_cls()(inputs, is_training, test_local_stats)

        zero_batch = jnp.zeros((4, 28, 28, 1), dtype=jnp.float32)
        params, state = model_pred.init(jax.random.PRNGKey(42),
                                        zero_batch,
                                        is_training=True)

        def logits_fun(inputs):
            # Evaluation mode; only the output (not the new state) is kept.
            return model_pred.apply(params, state, None, inputs, False,
                                    test_local_stats=False)[0]

        clean_logits = logits_fun(zero_batch)
        input_box = jax_verify.IntervalBound(zero_batch - 1.0,
                                             zero_batch + 1.0)

        relaxed_logits, env = bound_propagation.bound_propagation(
            bound_propagation.ForwardPropagationAlgorithm(
                relaxation.RelaxationTransform(jax_verify.ibp_transform)),
            logits_fun, input_box)

        # One-hot objective on logit 0; the solver minimizes it.
        first_logit = jnp.zeros(clean_logits.shape[1:]).at[0].set(1)
        batch_index = 0
        lower_bound, _, _ = relaxation.solve_relaxation(
            cvxpy_relaxation_solver.CvxpySolver, first_logit, 0.,
            relaxed_logits, env, batch_index)

        self.assertLessEqual(lower_bound, clean_logits[batch_index, 0])
コード例 #19
0
def nonconvex_constopt_bound_propagation(
    function: Callable[..., Nest[Tensor]],
    *bounds: Nest[graph_traversal.GraphInput],
    graph_simplifier=synthetic_primitives.default_simplifier,
) -> Nest[nonconvex.NonConvexBound]:
    """Builds the optimizable objective.

    Intermediate bounds are concretized with an `OptimizingConcretizer`
    wrapping `optimizers.PGDOptimizer(0, 0., optimize_dual=False)`, with up to
    512 nodes in parallel. NOTE(review): given the zero first argument, this
    presumably performs no optimization steps — confirm against PGDOptimizer.

    Args:
      function: Function performing computation to obtain bounds for. Takes as
        only arguments the network inputs.
      *bounds: Bounds on the inputs of the function.
      graph_simplifier: What graph simplifier to use.
    Returns:
      output_bounds: NonConvex bounds that can be optimized with a solver.
    """
    zero_step_pgd = optimizers.PGDOptimizer(0, 0., optimize_dual=False)
    concretizer = optimizers.OptimizingConcretizer(zero_step_pgd,
                                                   max_parallel_nodes=512)
    constopt_algorithm = nonconvex.nonconvex_algorithm(
        duals.WolfeNonConvexBound, concretizer)
    output_bounds, _ = bound_propagation.bound_propagation(
        constopt_algorithm,
        function,
        *bounds,
        graph_simplifier=graph_simplifier,
    )
    return output_bounds
コード例 #20
0
def backward_rvt_bound_propagation(
    function: Callable[..., Nest[Tensor]],
    *bounds: Nest[GraphInput]) -> Nest[LayerInput]:
  """Performs CROWN (https://arxiv.org/abs/1811.00866) with softmax expansion.

  Bounds are computed with `backward_crown_concretizer`, like backward CROWN.
  The computation graph, however, is simplified with a chain that includes
  `synthetic_primitives.expand_softmax_simplifier`, which presumably rewrites
  softmax into more elementary operations before bounding.

  Args:
    function: Function performing computation to obtain bounds for. Takes as
      only argument the network inputs.
    *bounds: jax_verify.IntervalBound, bounds on the inputs of the function.
  Returns:
    output_bound: Bounds on the output of the function obtained by CROWN.
  """
  backward_crown_algorithm = bound_utils.BackwardConcretizingAlgorithm(
      backward_crown_concretizer)
  # Simplifiers applied in sequence via simplifier_composition.
  expand_softmax_simplifier_chain = synthetic_primitives.simplifier_composition(
      synthetic_primitives.activation_simplifier,
      synthetic_primitives.hoist_constant_computations,
      synthetic_primitives.expand_softmax_simplifier,
      synthetic_primitives.group_linear_sequence,
      synthetic_primitives.group_posbilinear)
  output_bound, _ = bound_propagation.bound_propagation(
      backward_crown_algorithm, function, *bounds,
      graph_simplifier=expand_softmax_simplifier_chain)
  return output_bound
コード例 #21
0
 def bound_prop(function, *bounds) -> forward_linear_bounds.LinearBound:
     """Propagates `bounds` through `function` using `algorithm`.

     `algorithm` is a free name taken from the enclosing scope. Only the
     output bound is returned; the second value produced by
     `bound_propagation.bound_propagation` is discarded.
     """
     result_bound, _unused = bound_propagation.bound_propagation(
         algorithm, function, *bounds)
     return result_bound