예제 #1
0
    def test_integration_combined_layer(self):
        """Smoke-test PGA on two merged single-affine-layer instances.

        A hidden affine layer (``X_SHAPE[1]`` -> ``X_SHAPE[1] + 1``) and a
        scalar-output affine layer are wrapped in ``InnerVerifInstance``s,
        merged with ``dual_build._merge_instances`` and handed to
        ``pga.PgaStrategy.solve_max``. Nothing is asserted: the test only
        checks that this pipeline runs without raising.
        """
        key_hidden, key_lagrange, key_out, key_solve = jax.random.split(
            self.prng_key, 4)

        n_in = X_SHAPE[1]
        n_hidden = n_in + 1

        # Hidden-layer parameters and non-negative (abs of Gaussian)
        # pre-activation Lagrange parameters for the first instance.
        w_hidden = jax.random.normal(key_hidden, [n_in, n_hidden])
        b_hidden = jnp.zeros([n_hidden])
        lagrange_pre = jnp.abs(jax.random.normal(key_lagrange, [1, n_in]))

        # One Lagrangian form object is reused for the pre and post terms of
        # both instances.
        quad_form = QuadDiagForm()

        # Output layer: maps the hidden width down to a single value.
        w_out = jax.random.normal(key_out, [n_hidden, 1])
        b_out = jnp.zeros([])
        unbounded = sdp_utils.IntBound(lb=None, ub=None, lb_pre=None,
                                       ub_pre=None)
        out_bounds = [
            sdp_utils.IntBound(lb=-jnp.ones([1, n_hidden]),
                               ub=jnp.ones([1, n_hidden]),
                               lb_pre=None,
                               ub_pre=None),
            unbounded,
        ]

        def hidden_affine(x):
            return x @ w_hidden + b_hidden

        def out_affine(x):
            return x @ w_out + b_out

        # Settings identical for both instances.
        shared = dict(
            lagrangian_form_pre=quad_form,
            lagrangian_form_post=quad_form,
            is_first=False,
            lagrange_params_post=None,
            spec_type=verify_utils.SpecType.ADVERSARIAL,
            affine_before_relu=True,
        )

        hidden_instance = verify_utils.InnerVerifInstance(
            affine_fns=[hidden_affine],
            bounds=self.bounds,
            is_last=False,
            lagrange_params_pre=lagrange_pre,
            idx=0,
            **shared)
        out_instance = verify_utils.InnerVerifInstance(
            affine_fns=[out_affine],
            bounds=out_bounds,
            is_last=True,
            lagrange_params_pre=None,
            idx=1,
            **shared)

        merged = dual_build._merge_instances(hidden_instance, out_instance)

        # Run PGA (5 iterations, lr=0.01) on the merged problem; the return
        # value is intentionally not inspected.
        solver = pga.PgaStrategy(n_iter=5, lr=0.01)
        solver.solve_max(inner_dual_vars=None,
                         opt_instance=merged,
                         key=key_solve,
                         step=None)
예제 #2
0
    def setUp(self):
        """Create a fixed PRNG key and the default two-entry bounds list."""
        super().setUp()

        self.prng_key = jax.random.PRNGKey(1234)

        # First entry: a [-1, 1] box of shape X_SHAPE on the `lb`/`ub` fields
        # (the `*_pre` fields are left unset). Second entry: no bounds at all.
        unit_box = sdp_utils.IntBound(lb=-jnp.ones(X_SHAPE),
                                      ub=jnp.ones(X_SHAPE),
                                      lb_pre=None,
                                      ub_pre=None)
        no_bounds = sdp_utils.IntBound(lb=None, ub=None, lb_pre=None,
                                       ub_pre=None)
        self.bounds = [unit_box, no_bounds]
예제 #3
0
 def _bounds(self) -> Sequence[utils.IntBound]:
     """Return an `IntBound` for each non-affine `Bound` node in `self._env`.

     Nodes are visited in `self._env.values()` order. Each selected node's
     `lower`/`upper` become `lb`/`ub`; `lb_pre`/`ub_pre` are left as None.
     """
     collected = []
     for node in self._env.values():
         # Skip non-Bound entries first, so `is_affine` is only read on Bounds.
         if not isinstance(node, Bound) or node.is_affine:
             continue
         collected.append(
             utils.IntBound(lb=node.lower,
                            ub=node.upper,
                            lb_pre=None,
                            ub_pre=None))
     return collected
예제 #4
0
    def setUp(self):
        """Seed a Haiku PRNG sequence and build symmetric `*_pre` bounds."""
        super(UncertaintySpecTest, self).setUp()

        self._prng_seq = hk.PRNGSequence(13579)
        # Number of classes is read from the second entry of X_SHAPE.
        self._n_classes = X_SHAPE[1]

        # Only the `*_pre` fields are populated, with the box [-0.1, 0.1] of
        # shape X_SHAPE; `lb`/`ub` stay None.
        half_width = 0.1 * jnp.ones(X_SHAPE)
        self.bounds = [
            sdp_utils.IntBound(lb_pre=-half_width,
                               ub_pre=half_width,
                               lb=None,
                               ub=None),
        ]
예제 #5
0
def make_toy_verif_instance(seed=None,
                            label=None,
                            target_label=None,
                            nn='mlp'):
    """Build a small, deterministic verification instance for unit tests.

    Args:
      seed: Seed for `jax.random.PRNGKey`; None means 0.
      label: Forwarded to `utils.make_nn_verif_instance`; None means 2.
      target_label: Forwarded as well; None means 1.
      nn: 'mlp' builds an MLP with layer sizes 5, 5, 5. 'cnn_simple' builds a
        one-conv CNN with 'VALID' padding. Any other value builds the same CNN
        with 'SAME' padding.

    Returns:
      The instance returned by `utils.make_nn_verif_instance`, using bounds
      propagated by `utils.boundprop` from the input box [0, 1].
    """
    key = jax.random.PRNGKey(0 if seed is None else seed)

    if nn == 'mlp':
        layer_sizes = np.fromstring('5, 5, 5', dtype=int, sep=',')
        params = make_mlp_params(layer_sizes, key)
        inp_shape = (1, layer_sizes[0])
    else:
        # With 'cnn_simple' the input shape (1, 2, 2, 1) and the 2x2 filter
        # match, so VALID padding applies the filter at just one location.
        # Otherwise SAME padding pads the input on the right/bottom to 3x3.
        padding = 'VALID' if nn == 'cnn_simple' else 'SAME'
        conv_spec = {
            'n_h': 2,
            'n_w': 2,
            'n_cout': 2,
            'padding': padding,
            'stride': 1,
            'n_cin': 1,
        }
        layer_sizes = [(1, 2, 2, 1), conv_spec, 3]
        inp_shape = layer_sizes[0]
        params = make_cnn_params(layer_sizes, key)

    input_box = utils.IntBound(lb=np.zeros(inp_shape),
                               ub=np.ones(inp_shape),
                               lb_pre=None,
                               ub_pre=None)
    bounds = utils.boundprop(params, input_box)

    return utils.make_nn_verif_instance(
        params,
        bounds,
        target_label=1 if target_label is None else target_label,
        label=2 if label is None else label,
        input_bounds=(0., 1.))
예제 #6
0
def _compute_jv_bounds(
    input_bound: sdp_utils.IntBound,
    params: ModelParams,
    method: str,
) -> List[sdp_utils.IntBound]:
    """Propagate interval bounds through the network with jax_verify.

    Args:
      input_bound: Input bounds; only its `lb` and `ub` fields are read.
      params: Model parameters. Each entry with `has_bounds` set contributes
        its `(w_bound, b_bound)` pair as an extra bounded argument.
      method: One of 'ibp', 'fastlin', 'ibpfastlin', 'crown', 'nonconvex'.

    Returns:
      One `IntBound` per bound returned by jax_verify. The raw bounds go in
      `lb_pre`/`ub_pre`, and `lb`/`ub` hold them clipped below at 0
      (elementwise ReLU). The input bound itself is not included.

    Raises:
      ValueError: if `method` is not a supported name.
    """
    jv_input_bound = jax_verify.IntervalBound(input_bound.lb, input_bound.ub)

    # Function taking the input plus every bounded parameter and returning
    # all activations.
    all_act_fun = _make_all_act_fn(params)

    # Bounds for the parameters that have them, passed positionally below.
    jv_param_bounds = [(p.w_bound, p.b_bound) for p in params if p.has_bounds]

    # jax_verify entry point per method. Names are resolved lazily so that
    # only the requested function is looked up.
    propagator_names = {
        'ibp': 'interval_bound_propagation',
        'fastlin': 'forward_fastlin_bound_propagation',
        'ibpfastlin': 'ibpforwardfastlin_bound_propagation',
        'crown': 'backward_crown_bound_propagation',
        'nonconvex': 'nonconvex_constopt_bound_propagation',
    }
    if method not in propagator_names:
        raise ValueError('Unsupported method.')
    propagate = getattr(jax_verify, propagator_names[method])
    _, jv_bounds = propagate(all_act_fun, jv_input_bound, *jv_param_bounds)

    # Convert to the internal convention: raw bounds are the `*_pre` fields,
    # ReLU-clipped bounds are `lb`/`ub`.
    return [
        sdp_utils.IntBound(lb_pre=b.lower,
                           ub_pre=b.upper,
                           lb=jnp.maximum(b.lower, 0),
                           ub=jnp.maximum(b.upper, 0))
        for b in jv_bounds
    ]
예제 #7
0
def _nonconvex_boundprop(params,
                         x,
                         epsilon,
                         input_bounds,
                         nonconvex_boundprop_steps=100,
                         nonconvex_boundprop_nodes=128):
    """Compute layer-wise bounds with jax_verify's nonconvex bound propagation.

    Args:
      params: CNN parameters, bound into `utils.predict_cnn`.
      x: Passed with `epsilon` to `utils.init_bound` to build input bounds.
      epsilon: See `x`.
      input_bounds: Forwarded to `utils.init_bound`.
      nonconvex_boundprop_steps: `num_steps` for the FISTA optimizer.
      nonconvex_boundprop_nodes: `max_parallel_nodes` for the concretizer.

    Returns:
      A list whose first element is the initial bound from `utils.init_bound`.
      It is followed by one `IntBound` per intermediate bound: raw bounds in
      `lb_pre`/`ub_pre`, bounds clipped below at 0 in `lb`/`ub`.
    """
    start = utils.init_bound(x,
                             epsilon,
                             input_bounds=input_bounds,
                             add_batch_dim=False)

    # Forward pass that also emits pre-activations. The second element of the
    # propagated outputs is unpacked below as the intermediate bounds.
    forward = functools.partial(utils.predict_cnn,
                                params,
                                include_preactivations=True)

    start_interval = jax_verify.IntervalBound(start.lb, start.ub)

    concretizer = optimizers.OptimizingConcretizer(
        FistaOptimizer(num_steps=nonconvex_boundprop_steps),
        max_parallel_nodes=nonconvex_boundprop_nodes)
    algorithm = nonconvex.nonconvex_algorithm(duals.WolfeNonConvexBound,
                                              concretizer)

    (_, layer_bounds), _ = bound_propagation.bound_propagation(
        algorithm, forward, start_interval)

    return [start] + [
        utils.IntBound(lb_pre=b.lower,
                       ub_pre=b.upper,
                       lb=jnp.maximum(b.lower, 0),
                       ub=jnp.maximum(b.upper, 0))
        for b in layer_bounds
    ]