def compare_attacks(key, item, targeted=False):
    AdvertorchAttack = key
    CleverhansAttack = item["cl_class"]
    cl_kwargs = merge2dicts(item["kwargs"], item["cl_kwargs"])
    at_kwargs = merge2dicts(item["kwargs"], item["at_kwargs"])
    thresholds = item["thresholds"]

    # WARNING: don't use tf.InteractiveSession() here;
    # for some reason it forces the fastfeature attack to be the last test
    with tf.Session() as sess:
        # Build a PyTorch model and a TensorFlow model with identical weights
        # so that advertorch and cleverhans attack the same function.
        model_pt = SimpleModel(DIM_INPUT, NUM_CLASS)
        model_tf = SimpleModelTf(DIM_INPUT, NUM_CLASS)
        model_tf.load_state_dict(model_pt.state_dict())

        adversary = AdvertorchAttack(model_pt, **at_kwargs)

        if AdvertorchAttack is FastFeatureAttack:
            # FastFeatureAttack needs a guide input and a fixed initial
            # perturbation, so both sides are run from the same delta.
            model_tf_fastfeature = setup_simple_model(model_pt, inputs.shape)
            delta = np.random.uniform(
                -item["kwargs"]['eps'], item["kwargs"]['eps'],
                size=inputs.shape).astype('float32')
            inputs_guide = np.random.uniform(
                0, 1, size=(BATCH_SIZE, DIM_INPUT)).astype('float32')

            inputs_tf = tf.convert_to_tensor(inputs, np.float32)
            inputs_guide_tf = tf.convert_to_tensor(inputs_guide, np.float32)

            attack = CleverhansAttack(model_tf_fastfeature)
            cl_result = overwrite_fastfeature(attack,
                                              x=inputs_tf,
                                              g=inputs_guide_tf,
                                              eta=delta,
                                              **cl_kwargs)
            init = tf.global_variables_initializer()
            sess.run(init)
            ptb_cl = sess.run(cl_result) - inputs
            ptb_at = genenerate_ptb_pt(
                adversary, inputs, inputs_guide, delta=delta)
        else:
            attack = CleverhansAttack(model_tf, sess=sess)
            if targeted:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    ptb_cl = attack.generate_np(
                        inputs, y_target=targets_onehot, **cl_kwargs) - inputs
                ptb_at = genenerate_ptb_pt(adversary, inputs, targets=targets)
            else:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    ptb_cl = attack.generate_np(
                        inputs, y=None, **cl_kwargs) - inputs
                ptb_at = genenerate_ptb_pt(adversary, inputs, targets=None)

        if AdvertorchAttack is CarliniWagnerL2Attack:
            assert np.sum(np.abs(ptb_at)) > 0 and np.sum(np.abs(ptb_cl)) > 0, \
                ("Both advertorch and cleverhans return zero perturbation"
                 " for CarliniWagnerL2Attack; the test results are not"
                 " reliable. Adjust your testing parameters to avoid this.")

        compare_at_cl(ptb_at, ptb_cl, **thresholds)
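
# NOTE: merge2dicts is used by compare_attacks above but is assumed to be
# defined elsewhere (it is not in this file). A minimal sketch of the
# behaviour the call sites rely on, namely that keys from the second dict
# override the first, might look like this; it is an illustration under that
# assumption, not the actual implementation.
def _sketch_merge2dicts(dict_a, dict_b):
    """Return a new dict with dict_a's entries updated by dict_b's."""
    merged = dict(dict_a)
    merged.update(dict_b)
    return merged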
def _generate_models():
    # model with mixed training and requires_grad flags across layers
    mix_model = SimpleModel()
    mix_model.fc1.training = True
    mix_model.fc2.training = False
    mix_model.fc1.weight.requires_grad = False
    mix_model.fc2.bias.requires_grad = False
    # models entirely in train or eval mode
    trainon_model = SimpleModel()
    trainon_model.train()
    trainoff_model = SimpleModel()
    trainoff_model.eval()
    # models with gradients enabled or disabled for all parameters
    gradon_model = SimpleModel()
    gradoff_model = SimpleModel()
    set_param_grad_off(gradoff_model)
    return (mix_model, gradon_model, gradoff_model,
            trainon_model, trainoff_model)
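
# NOTE: set_param_grad_off, used by _generate_models above, is assumed to be
# defined elsewhere. A minimal sketch of a helper with that behaviour,
# disabling gradient tracking on every parameter of a module, is shown below
# for illustration only; it is not the actual implementation.
def _sketch_set_param_grad_off(model):
    """Turn off gradient tracking for every parameter of the module."""
    for param in model.parameters():
        param.requires_grad = False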