def compare_attacks(key, item, targeted=False):
    """Run an advertorch attack and a cleverhans attack and compare them.

    Both attacks are applied to the module-level ``inputs``; the cleverhans
    perturbation is the attack output minus ``inputs``, the advertorch one
    comes from ``genenerate_ptb_pt``. The pair is then passed to
    ``compare_at_cl`` together with the configured thresholds.

    Args:
        key: the advertorch attack class; instantiated with the PyTorch
            model and the merged advertorch keyword arguments.
        item: mapping with the entries ``"cl_class"`` (cleverhans attack
            class), ``"kwargs"`` (arguments common to both libraries),
            ``"cl_kwargs"`` / ``"at_kwargs"`` (library-specific arguments)
            and ``"thresholds"`` (keyword arguments for ``compare_at_cl``).
        targeted: when true, the cleverhans attack receives
            ``y_target=targets_onehot`` and the advertorch helper receives
            ``targets``. Not consulted for ``FastFeatureAttack``.
    """
    at_attack_cls = key
    cl_attack_cls = item["cl_class"]
    shared_kwargs = item["kwargs"]
    cl_kwargs = merge2dicts(shared_kwargs, item["cl_kwargs"])
    at_kwargs = merge2dicts(item["kwargs"], item["at_kwargs"])
    thresholds = item["thresholds"]

    # WARNING: do not switch this to tf.InteractiveSession().
    # With it, the fast-feature attack test has to run last, for reasons
    # that were never identified.
    with tf.Session() as sess:
        model_pt = SimpleModel(DIM_INPUT, NUM_CLASS)
        model_tf = SimpleModelTf(DIM_INPUT, NUM_CLASS)
        # Load the PyTorch model's state dict into the TF model.
        model_tf.load_state_dict(model_pt.state_dict())
        adversary = at_attack_cls(model_pt, **at_kwargs)

        if at_attack_cls is not FastFeatureAttack:
            cl_attack = cl_attack_cls(model_tf, sess=sess)
            # Warnings raised while cleverhans generates examples are
            # silenced; the advertorch call below is left unfiltered.
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                if targeted:
                    ptb_cl = cl_attack.generate_np(
                        inputs, y_target=targets_onehot, **cl_kwargs) - inputs
                else:
                    ptb_cl = cl_attack.generate_np(
                        inputs, y=None, **cl_kwargs) - inputs
            ptb_at = genenerate_ptb_pt(
                adversary, inputs, targets=targets if targeted else None)

        else:
            # FastFeatureAttack takes a dedicated TF model, a guide batch
            # and an explicit initial perturbation shared by both sides.
            fastfeature_model_tf = setup_simple_model(model_pt, inputs.shape)
            eps = shared_kwargs['eps']
            delta = np.random.uniform(
                -eps, eps, size=inputs.shape).astype('float32')
            inputs_guide = np.random.uniform(
                0, 1, size=(BATCH_SIZE, DIM_INPUT)).astype('float32')
            inputs_tf = tf.convert_to_tensor(inputs, np.float32)
            inputs_guide_tf = tf.convert_to_tensor(inputs_guide, np.float32)
            cl_attack = cl_attack_cls(fastfeature_model_tf)
            cl_result = overwrite_fastfeature(
                cl_attack,
                x=inputs_tf,
                g=inputs_guide_tf,
                eta=delta,
                **cl_kwargs)
            sess.run(tf.global_variables_initializer())
            ptb_cl = sess.run(cl_result) - inputs
            ptb_at = genenerate_ptb_pt(
                adversary, inputs, inputs_guide, delta=delta)

        if at_attack_cls is CarliniWagnerL2Attack:
            # Require a non-zero perturbation from each library before
            # comparing them.
            assert (np.sum(np.abs(ptb_at)) > 0
                    and np.sum(np.abs(ptb_cl)) > 0), (
                "Both advertorch and cleverhans returns zero perturbation"
                " of CarliniWagnerL2Attack, "
                "the test results are not reliable,"
                " Adjust your testing parameters to avoid this."
            )

        compare_at_cl(ptb_at, ptb_cl, **thresholds)
# Example no. 2
# 0
def _generate_models():
    """Build five ``SimpleModel`` instances with differing flag settings.

    Returns:
        A tuple ``(mix_model, gradon_model, gradoff_model, trainon_model,
        trainoff_model)`` where

        * ``mix_model`` has ``fc1.training`` set to True and
          ``fc2.training`` set to False, with ``requires_grad`` switched
          off on ``fc1.weight`` and ``fc2.bias`` only;
        * ``gradon_model`` is a freshly constructed model left untouched;
        * ``gradoff_model`` has been passed through ``set_param_grad_off``;
        * ``trainon_model`` has had ``train()`` called on it;
        * ``trainoff_model`` has had ``eval()`` called on it.
    """
    # Models are created in the same order as before (mix, train-on,
    # train-off, grad-on, grad-off); only the return order differs.
    mixed = SimpleModel()
    mixed.fc1.training, mixed.fc2.training = True, False
    mixed.fc1.weight.requires_grad = False
    mixed.fc2.bias.requires_grad = False

    train_on = SimpleModel()
    train_on.train()

    train_off = SimpleModel()
    train_off.eval()

    grad_on = SimpleModel()

    grad_off = SimpleModel()
    set_param_grad_off(grad_off)

    return (mixed, grad_on, grad_off, train_on, train_off)