def boundprop(params, x, epsilon, input_bounds):
    """Propagate bounds through `params` starting from one batched input.

    Args:
      params: model parameters, forwarded unchanged to `utils.boundprop`.
      x: input whose `shape` must have four entries with a leading entry of 1.
      epsilon: perturbation radius forwarded to `utils.init_bound`.
      input_bounds: forwarded to `utils.init_bound` as `input_bounds`.

    Returns:
      The result of `utils.boundprop` applied to the initial bound built
      from `x[0]`.
    """
    input_shape = x.shape
    assert len(input_shape) == 4 and input_shape[0] == 1, (
        f'shape check {x.shape}')
    # The initial bound is built from the single example, without the
    # leading axis.
    return utils.boundprop(
        params,
        utils.init_bound(x[0], epsilon, input_bounds=input_bounds),
    )
def _compute_standard_bounds(
    x: Tensor,
    epsilon: float,
    input_bounds: Sequence[int],
    params: ModelParams,
    config: Union[ConfigDict, Dict[str, Any]],
):
    """Run bound propagation for layers whose parameters carry no bounds.

    Args:
      x: input to the model under verification.
      epsilon: radius of the l-infinity ball around x.
      input_bounds: feasibility bounds of inputs (e.g. [0, 1]).
      params: parameters of the model under verification; no entry may have
        `has_bounds` set.
      config: experiment ConfigDict (or plain dict). 'boundprop_type' selects
        the method; the 'nonconvex' method additionally reads
        'nonconvex_boundprop_steps' and 'nonconvex_boundprop_nodes'.

    Returns:
      List of bounds per layer, including the input bounds as the first
      element.

    Raises:
      ValueError: if any entry of `params` has bounds on its parameters.
    """
    if any(layer_params.has_bounds for layer_params in params):
        raise ValueError('Unsupported bilinear bound propagation.')

    method = config['boundprop_type']

    if method == 'nonconvex':
        return boundprop_utils.boundprop(
            params,
            jnp.expand_dims(x, axis=0),
            epsilon,
            input_bounds,
            'nonconvex',
            nonconvex_boundprop_steps=config['nonconvex_boundprop_steps'],
            nonconvex_boundprop_nodes=config['nonconvex_boundprop_nodes'],
        )

    if method == 'crown_ibp':
        return boundprop_utils.boundprop(
            params,
            jnp.expand_dims(x, axis=0),
            epsilon,
            input_bounds,
            'crown_ibp',
        )

    # Any other method is delegated to _compute_jv_bounds, seeded with the
    # initial input bound, which is also the first element of the result.
    starting_bound = sdp_utils.init_bound(
        x, epsilon, input_bounds=input_bounds)
    return [starting_bound] + _compute_jv_bounds(
        input_bound=starting_bound,
        params=params,
        method=method,
    )
def solve_with_functional_lagrangian(self):
    """Solve the verification dual using linear Lagrangian forms.

    Propagates bounds through the network followed by the objective layer,
    initialises the duals with an `lp.LpStrategy` inner optimiser under an
    adversarial spec, runs `dual_solve.solve_dual_train`, and reports the
    loss stored on the dual state.

    Returns:
      `dual_state.loss` as left on the dual state after
      `dual_solve.solve_dual_train` has run.
    """
    config = get_config()

    # Bounds for every layer, with the objective appended as a final layer.
    first_bound = sdp_utils.init_bound(
        self.inputs[0], self.eps, input_bounds=self.input_bounds)
    layers_with_objective = self.network_params + [
        (self.objective, self.objective_bias)
    ]
    bounds = sdp_utils.boundprop(layers_with_objective, first_bound)

    logits_fn = make_model_fn(self.network_params)

    def spec_fn(inputs):
        # Objective applied on top of the network logits.
        logits = logits_fn(inputs)
        return jnp.matmul(logits, self.objective) + self.objective_bias

    root_bounds = jax_verify.IntervalBound(bounds[0].lb, bounds[0].ub)

    # The same Linear form instance is reused for every entry of `bounds`.
    linear_form = lagrangian_form.Linear()
    forms_per_layer = [linear_form for _ in bounds]

    inner_opt = lp.LpStrategy()
    env, dual_params, dual_params_types = inner_opt.init_duals(
        jax_verify.ibp_transform,
        verify_utils.SpecType.ADVERSARIAL,
        False,
        spec_fn,
        self.keys[3],
        forms_per_layer,
        root_bounds,
    )

    opt, num_steps = dual_build.make_opt_and_num_steps(config.outer_opt)
    dual_state = ml_collections.ConfigDict(type_safe=False)

    dual_solve.solve_dual_train(
        env,
        key=self.keys[4],
        num_steps=num_steps,
        opt=opt,
        dual_params=dual_params,
        dual_params_types=dual_params_types,
        dual_state=dual_state,
        affine_before_relu=False,
        spec_type=verify_utils.SpecType.ADVERSARIAL,
        inner_opt=inner_opt,
        logger=(lambda *args: None),  # no-op logger
    )
    return dual_state.loss
def get_verif_instance(params, x, label, target_label, epsilon,
                       input_bounds=(0., 1.)):
    """Creates a ReLU robustness verification instance for one example.

    The layer bounds come from `utils.boundprop` when
    `FLAGS.boundprop_type` is 'ibp'; otherwise they come from
    `boundprop_utils.boundprop`, which is given `x` with an extra leading
    axis and the flag value as the method name.

    Args:
      params: model parameters.
      x: input example.
      label: label forwarded to the verification instance.
      target_label: target label forwarded to the verification instance.
      epsilon: perturbation radius used to build the bounds.
      input_bounds: input feasibility bounds; defaults to (0., 1.).

    Returns:
      The result of `utils.make_relu_robust_verif_instance`.
    """
    method = FLAGS.boundprop_type
    if method == 'ibp':
        start_bound = utils.init_bound(x, epsilon, input_bounds=input_bounds)
        bounds = utils.boundprop(params, start_bound)
    else:
        batched_x = np.expand_dims(x, axis=0)
        bounds = boundprop_utils.boundprop(
            params, batched_x, epsilon, input_bounds, method)

    return utils.make_relu_robust_verif_instance(
        params,
        bounds,
        target_label=target_label,
        label=label,
        input_bounds=input_bounds,
    )
def _nonconvex_boundprop(params, x, epsilon, input_bounds,
                         nonconvex_boundprop_steps=100,
                         nonconvex_boundprop_nodes=128):
    """Wrapper for nonconvex bound propagation.

    Args:
      params: model parameters bound into `utils.predict_cnn`.
      x: input forwarded to `utils.init_bound` (no batch dimension is added).
      epsilon: perturbation radius forwarded to `utils.init_bound`.
      input_bounds: forwarded to `utils.init_bound` as `input_bounds`.
      nonconvex_boundprop_steps: `num_steps` of the `FistaOptimizer`
        (default 100).
      nonconvex_boundprop_nodes: `max_parallel_nodes` of the
        `OptimizingConcretizer` (default 128).

    Returns:
      A list starting with the initial bounds, followed by one
      `utils.IntBound` per intermediate nonconvex bound.
    """
    # Initial bounds that seed the propagation.
    init_bounds = utils.init_bound(
        x, epsilon, input_bounds=input_bounds, add_batch_dim=False)
    input_bound = jax_verify.IntervalBound(init_bounds.lb, init_bounds.ub)

    # Function to propagate through; it is asked for pre-activations as well.
    all_act_fun = functools.partial(
        utils.predict_cnn, params, include_preactivations=True)

    concretizer = optimizers.OptimizingConcretizer(
        FistaOptimizer(num_steps=nonconvex_boundprop_steps),
        max_parallel_nodes=nonconvex_boundprop_nodes,
    )
    algorithm = nonconvex.nonconvex_algorithm(
        duals.WolfeNonConvexBound, concretizer)

    all_outputs, _ = bound_propagation.bound_propagation(
        algorithm, all_act_fun, input_bound)
    _, intermediate_nonconvex_bounds = all_outputs

    # lb/ub are the pre-activation bounds clamped from below at zero.
    layer_bounds = [
        utils.IntBound(
            lb_pre=layer_bound.lower,
            ub_pre=layer_bound.upper,
            lb=jnp.maximum(layer_bound.lower, 0),
            ub=jnp.maximum(layer_bound.upper, 0),
        )
        for layer_bound in intermediate_nonconvex_bounds
    ]
    return [init_bounds] + layer_bounds
def get_bounds(
    x: Tensor,
    epsilon: float,
    input_bounds: Sequence[int],
    params: ModelParams,
    config: Union[ConfigDict, Dict[str, Any]],
) -> List[sdp_utils.IntBound]:
    """Perform bound-propagation and return bounds.

    The sequential model is assumed to split into two (possibly empty) parts.
    The leading part has no bounds on its parameters and goes through regular
    bound propagation. The trailing part starts at the first layer with
    parameter bounds and goes through a method that supports bilinear bound
    propagation.

    Args:
      x: input to the model under verification.
      epsilon: radius of l-infinity ball around x.
      input_bounds: feasibility bounds of inputs (e.g. [0, 1]).
      params: parameters of the model under verification.
      config: experiment ConfigDict; 'boundprop_type' and
        'bilinear_boundprop_type' are read from it.

    Returns:
      List of bounds per layer, including the input bounds as the first
      element.
    """
    standard_method = config['boundprop_type']
    bilinear_method = config['bilinear_boundprop_type']

    if standard_method != bilinear_method:
        # Different methods: split at the first layer that has parameter
        # bounds. Everything before it uses "standard" boundprop.
        split_at = next(
            (idx for idx, layer in enumerate(params) if layer.has_bounds),
            len(params),
        )
        params_standard = params[:split_at]
        params_bilinear = params[split_at:]
    else:
        # Same method for both: treat every layer as bilinear.
        params_standard, params_bilinear = [], params

    if params_standard:
        bounds_standard = _compute_standard_bounds(
            x=x,
            epsilon=epsilon,
            input_bounds=input_bounds,
            params=params_standard,
            config=config,
        )
    else:
        bounds_standard = [
            sdp_utils.init_bound(x, epsilon, input_bounds=input_bounds)
        ]

    # The bilinear part continues from the last bound of the standard part.
    bounds_bilinear = (
        _compute_jv_bounds(
            input_bound=bounds_standard[-1],
            params=params_bilinear,
            method=bilinear_method,
        )
        if params_bilinear
        else []
    )
    return bounds_standard + bounds_bilinear