Example 1
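This example shows a run_attack driver for the perceptron robustness package: it builds the model, distance, criterion, and metric from the parsed command-line arguments, runs the chosen attack, and prints and plots a summary of the result.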
import ast

import numpy as np

# the helper functions and bcolors below are assumed to come from the
# package's shared utilities (exact module paths may differ per checkout)
from perceptron.utils.tools import (bcolors, get_criteria, get_distance,
                                    get_image, get_metric, get_model,
                                    plot_image, plot_image_objectdetection)


def run_attack(args, summary):
    """Run the attack described by the parsed command-line arguments."""
    model = get_model(args.model, args.framework, summary)
    distance = get_distance(args.distance)
    criteria = get_criteria(args.criteria, args.target_class)

    channel_axis = 'channels_first' if args.framework == 'pytorch' else 'channels_last'
    image = get_image(args.image, args.framework, args.model, channel_axis)
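    # choose the label to attack: predict it locally unless the model is a
    # cloud API (which returns its own label) or a target class was given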
    if args.target_class is None and args.framework != 'cloud':
        label = np.argmax(model.predictions(image))
    elif args.framework == 'cloud':
        label = model.predictions(image)
    else:
        label = args.target_class

    metric = get_metric(args.metric, model, criteria, distance)

    print(bcolors.BOLD + 'Process start' + bcolors.ENDC)
    if args.model in ["yolo_v3", "keras_ssd300", "retina_resnet_50"]:
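        # detection models take no label; the criterion (e.g. a target
        # class miss) decides whether the attack succeeded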
        adversary = metric(image, unpack=False, binary_search_steps=1)
    elif args.metric not in ["carlini_wagner_l2", "carlini_wagner_linf"]:
        if args.metric in summary['verifiable_metrics']:
            adversary = metric(image, label, unpack=False, verify=args.verify)
        else:
            adversary = metric(image, label, unpack=False, epsilons=1000)
    else:
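        # Carlini-Wagner attacks: keep the binary-search and iteration
        # budgets small so the example finishes quickly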
        adversary = metric(image,
                           label,
                           unpack=False,
                           binary_search_steps=10,
                           max_iterations=5)
    print(bcolors.BOLD + 'Process finished' + bcolors.ENDC)

    if adversary.image is None:
        print(bcolors.WARNING + 'Warning: Cannot find an adversary!' +
              bcolors.ENDC)
        return adversary

    ###################  print summary info  #####################################
    keywords = [args.framework, args.model, args.criteria, args.metric]

    if args.model not in ["yolo_v3", "keras_ssd300",
                          "retina_resnet_50"]:  # classification
        # map the numeric label to a human-readable class name
        with open('perceptron/utils/labels.txt') as info:
            imagenet_dict = ast.literal_eval(info.read())
        if args.framework != 'cloud':
            true_label = imagenet_dict[np.argmax(model.predictions(image))]
            fake_label = imagenet_dict[np.argmax(
                model.predictions(adversary.image))]
        else:
            true_label = str(model.predictions(image))
            fake_label = str(model.predictions(adversary.image))

        print(bcolors.HEADER + bcolors.UNDERLINE + 'Summary:' + bcolors.ENDC)
        print('Configuration:' + bcolors.CYAN + ' --framework %s '
              '--model %s --criterion %s '
              '--metric %s' % tuple(keywords) + bcolors.ENDC)
        print('The predicted label of original image is ' + bcolors.GREEN +
              true_label + bcolors.ENDC)
        print('The predicted label of adversary image is ' + bcolors.RED +
              fake_label + bcolors.ENDC)
        print('Minimum perturbation required: ' + bcolors.BLUE +
              str(adversary.distance) + bcolors.ENDC)
        if args.metric in [
                "brightness", "rotation", "horizontal_translation",
                "vertical_translation"
        ]:
            print('Verifiable bound: ' + bcolors.BLUE +
                  str(adversary.verifiable_bounds) + bcolors.ENDC)
        print('\n')

        plot_image(adversary,
                   title=', '.join(keywords),
                   figname='examples/images/%s.png' % '_'.join(keywords))
    else:  # object detection
        print(bcolors.HEADER + bcolors.UNDERLINE + 'Summary:' + bcolors.ENDC)
        print('Configuration:' + bcolors.CYAN + ' --framework %s '
              '--model %s --criterion %s '
              '--metric %s' % tuple(keywords) + bcolors.ENDC)

        print('Minimum perturbation required: ' + bcolors.BLUE +
              str(adversary.distance) + bcolors.ENDC)
        print('\n')

        plot_image_objectdetection(adversary,
                                   model,
                                   title=", ".join(keywords),
                                   figname='examples/images/%s.png' %
                                   '_'.join(keywords))
    if args.framework == 'keras':
        from perceptron.utils.func import clear_keras_session
        clear_keras_session()
    return adversary
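
A minimal driver sketch for run_attack. The flag names mirror the attribute
names the function reads; the argparse wiring and the default values are
assumptions, not the package's actual CLI:

import argparse

parser = argparse.ArgumentParser(description='Run a robustness benchmark.')
parser.add_argument('--framework', default='keras')
parser.add_argument('--model', default='resnet50')      # assumed model key
parser.add_argument('--criteria', default='misclassification')
parser.add_argument('--metric', default='brightness')
parser.add_argument('--distance', default='mse')        # assumed distance key
parser.add_argument('--image', default='example.png')   # assumed test image
parser.add_argument('--target_class', type=int, default=None)
parser.add_argument('--verify', action='store_true')
args = parser.parse_args()

# metrics treated as verifiable, taken from the check inside run_attack
summary = {'verifiable_metrics': ['brightness', 'rotation',
                                  'horizontal_translation',
                                  'vertical_translation']}

run_attack(args, summary)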
Example 2

This second snippet is a standalone script fragment (Keras SSD300 model,
TargetClassMiss criterion, BrightnessMetric). It assumes kmodel, image, and
metric already exist; a minimal setup sketch follows the snippet.

print(bcolors.BOLD + 'Process start' + bcolors.ENDC)
adversary = metric(image, unpack=False)
print(bcolors.BOLD + 'Process finished' + bcolors.ENDC)

if adversary.image is None:
    print(bcolors.WARNING + 'Warning: Cannot find an adversary!' +
          bcolors.ENDC)
    exit(-1)

###################  print summary info  #####################################

keywords = ['Keras', 'SSD300', 'TargetClassMiss', 'BrightnessMetric']

print(bcolors.HEADER + bcolors.UNDERLINE + 'Summary:' + bcolors.ENDC)
print('Configuration:' + bcolors.CYAN + ' --framework %s '
      '--model %s --criterion %s '
      '--metric %s' % tuple(keywords) + bcolors.ENDC)

print('Minimum perturbation required: ' + bcolors.BLUE +
      str(adversary.distance) + bcolors.ENDC)
print('\n')

# print the original image and the adversary
plot_image_objectdetection(adversary,
                           kmodel,
                           bounds=(0, 255),
                           title=", ".join(keywords),
                           figname='examples/images/%s.png' %
                           '_'.join(keywords))
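
The snippet above assumes kmodel, image, and metric were created earlier in
the script. A minimal setup sketch, assuming Perceptron Benchmark's SSD300
zoo model and detection wrappers (the module paths, the car.png fixture, and
the target class are assumptions; adjust them to your installation):

from perceptron.benchmarks.brightness import BrightnessMetric
from perceptron.models.detection.keras_ssd300 import KerasSSD300Model
from perceptron.utils.criteria.detection import TargetClassMiss
from perceptron.utils.image import load_image
from perceptron.zoo.ssd300.keras_ssd300 import SSD300

# wrap the Keras SSD300 detector; it expects pixel values in [0, 255]
kmodel = KerasSSD300Model(SSD300(), bounds=(0, 255))

# load a 300x300 test image in the same value range (fixture name assumed)
image = load_image(shape=(300, 300), bounds=(0, 255), fname='car.png')

# brighten the image until the target class (assumed id 3) is no longer
# detected
metric = BrightnessMetric(kmodel, criterion=TargetClassMiss(target_class=3))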