Exemplo n.º 1
0
def compare_attacks(key, item):
    """Check an advertorch attack against its foolbox counterpart.

    Runs the advertorch attack ``key`` on the module-level ``img_batch`` /
    ``label_batch``. For every sample that the model classifies correctly
    AND that the advertorch attack manages to flip, the foolbox attack built
    from ``item["fb_class"]`` is run on the same sample and the two
    adversarial examples are compared with ``compare_at_fb``.

    Args:
        key: the advertorch attack class under test.
        item: test configuration dict; reads ``"fb_class"``, ``"kwargs"``,
            ``"fb_kwargs"``, ``"at_kwargs"`` and ``"thresholds"``.

    Raises:
        RuntimeError: if the foolbox attack returned ``None`` for every
            sample it was run on (nothing was actually compared).
    """
    advertorch_cls = key
    foolbox_model = foolbox.models.PyTorchModel(
        model, bounds=(0, 1), num_classes=NUM_CLASS)
    foolbox_attack = item["fb_class"](foolbox_model)
    fb_kwargs = merge2dicts(item["kwargs"], item["fb_kwargs"])
    at_kwargs = merge2dicts(item["kwargs"], item["at_kwargs"])
    thresholds = item["thresholds"]

    x_adv = advertorch_cls(model, **at_kwargs).perturb(img_batch, label_batch)
    clean_pred = predict_from_logits(model(img_batch))
    adv_pred = predict_from_logits(model(x_adv))

    n_compared = 0
    for idx, (x_clean, label) in enumerate(zip(img_batch, label_batch)):
        # Skip samples the model already gets wrong and samples where the
        # advertorch attack did not succeed: only foolbox failures are of
        # interest below.
        if label != clean_pred[idx] or label == adv_pred[idx]:
            continue

        # Reseed numpy before every foolbox call so each call starts from the
        # same global RNG state (presumably foolbox draws from it).
        np.random.seed(233333)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            x_fb = foolbox_attack(
                x_clean.cpu().numpy(), label=int(label), **fb_kwargs)

        # foolbox signals "attack failed" by returning None.
        if x_fb is None:
            continue
        compare_at_fb(x_adv[idx].cpu().numpy(), x_fb, **thresholds)
        n_compared += 1

    if n_compared == 0:
        raise RuntimeError(
            "Foolbox never succeed, change your testing parameters!!!")
def compare_attacks(key, item, targeted=False):
    """Compare an advertorch attack against its cleverhans counterpart.

    Builds a PyTorch ``SimpleModel`` and a TF ``SimpleModelTf`` loaded with
    the same weights, runs the advertorch attack ``key`` and the cleverhans
    attack ``item["cl_class"]`` on the module-level ``inputs``, and checks
    that the resulting perturbations agree via ``compare_at_cl``.

    Args:
        key: the advertorch attack class under test.
        item: test configuration dict; reads ``"cl_class"``, ``"kwargs"``,
            ``"cl_kwargs"``, ``"at_kwargs"`` and ``"thresholds"``.
        targeted: if True, run both attacks as targeted attacks using the
            module-level ``targets`` / ``targets_onehot``; otherwise run them
            untargeted.

    Raises:
        AssertionError: for ``CarliniWagnerL2Attack``, when either library
            returns an all-zero (or NaN) perturbation, since comparing such
            results is not meaningful.
        Whatever ``compare_at_cl`` raises when the perturbations disagree.
    """
    AdvertorchAttack = key
    CleverhansAttack = item["cl_class"]
    cl_kwargs = merge2dicts(item["kwargs"], item["cl_kwargs"])
    at_kwargs = merge2dicts(item["kwargs"], item["at_kwargs"])
    thresholds = item["thresholds"]

    # WARNING: don't use tf.InteractiveSession() here
    # It causes that fastfeature attack has to be the last test for some reason
    with tf.Session() as sess:
        model_pt = SimpleModel(DIM_INPUT, NUM_CLASS)
        model_tf = SimpleModelTf(DIM_INPUT, NUM_CLASS)
        # Give the TF model the PyTorch model's weights so both attacks
        # target the same function.
        model_tf.load_state_dict(model_pt.state_dict())
        adversary = AdvertorchAttack(model_pt, **at_kwargs)

        if AdvertorchAttack is FastFeatureAttack:
            model_tf_fastfeature = setup_simple_model(model_pt, inputs.shape)
            # A single random initial perturbation in [-eps, eps] and a single
            # random guide batch are drawn here and fed to BOTH
            # implementations, so their outputs are directly comparable.
            delta = np.random.uniform(
                -item["kwargs"]['eps'], item["kwargs"]['eps'],
                size=inputs.shape).astype('float32')
            inputs_guide = np.random.uniform(
                0, 1, size=(BATCH_SIZE, DIM_INPUT)).astype('float32')
            inputs_tf = tf.convert_to_tensor(inputs, np.float32)
            inputs_guide_tf = tf.convert_to_tensor(inputs_guide, np.float32)
            attack = CleverhansAttack(model_tf_fastfeature)
            cl_result = overwrite_fastfeature(attack,
                                              x=inputs_tf,
                                              g=inputs_guide_tf,
                                              eta=delta,
                                              **cl_kwargs)
            init = tf.global_variables_initializer()
            sess.run(init)
            # Both sides are compared as perturbations (adversarial - clean).
            ptb_cl = sess.run(cl_result) - inputs
            ptb_at = genenerate_ptb_pt(
                adversary, inputs, inputs_guide, delta=delta)

        else:
            attack = CleverhansAttack(model_tf, sess=sess)
            if targeted:
                # Silence warnings emitted while cleverhans generates.
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    ptb_cl = attack.generate_np(
                        inputs, y_target=targets_onehot, **cl_kwargs) - inputs
                ptb_at = genenerate_ptb_pt(adversary, inputs, targets=targets)
            else:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    ptb_cl = attack.generate_np(
                        inputs, y=None, **cl_kwargs) - inputs
                ptb_at = genenerate_ptb_pt(adversary, inputs, targets=None)

        if AdvertorchAttack is CarliniWagnerL2Attack:
            # An all-zero perturbation from either library makes the
            # comparison below unreliable, so fail and name the culprit(s).
            # `not (... > 0)` also flags NaN sums. An explicit raise (rather
            # than `assert`) keeps the check active under `python -O`.
            zero_ptb_libs = [
                lib for lib, ptb in (("advertorch", ptb_at),
                                     ("cleverhans", ptb_cl))
                if not (np.sum(np.abs(ptb)) > 0)
            ]
            if zero_ptb_libs:
                raise AssertionError(
                    "{} returned zero perturbation of CarliniWagnerL2Attack, "
                    "the test results are not reliable. "
                    "Adjust your testing parameters to avoid this.".format(
                        " and ".join(zero_ptb_libs)))

        compare_at_cl(ptb_at, ptb_cl, **thresholds)