예제 #1
0
파일: app.py 프로젝트: mlx-bot/katalog
def robustness_api():
    """Run ``robustness_check`` for the model described by the JSON request body.

    The request body must be a JSON object containing the keys
    ``aws_endpoint_url``, ``training_results_bucket``, ``aws_access_key_id``,
    ``aws_secret_access_key`` and ``model_id``. Their values are passed
    positionally to ``robustness_check``.

    Returns:
        The value returned by ``robustness_check``, serialized with
        ``json.dumps``.

    Aborts with HTTP 400 when the body is missing, is not JSON, is not a JSON
    object, or lacks any of the required keys.
    """
    # silent=True makes get_json return None (instead of raising) for a
    # missing, non-JSON or malformed body, so all of those cases end up in the
    # same 400 response below without a blanket ``except:``.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        abort(400)

    try:
        s3_url = payload['aws_endpoint_url']
        bucket_name = payload['training_results_bucket']
        s3_username = payload['aws_access_key_id']
        s3_password = payload['aws_secret_access_key']
        model_id = payload['model_id']
    except KeyError:
        # A required field is missing from the request body.
        abort(400)

    return json.dumps(robustness_check(s3_url, bucket_name, s3_username, s3_password, model_id))
예제 #2
0
    label_testset_path = args.label_testset_path
    data_bucket_name = args.data_bucket_name
    result_bucket_name = args.result_bucket_name
    clip_values = eval(args.clip_values)
    input_shape = eval(args.input_shape)

    object_storage_url = get_secret('/app/secrets/s3_url', 'minio-service:9000')
    object_storage_username = get_secret('/app/secrets/s3_access_key_id', 'minio')
    object_storage_password = get_secret('/app/secrets/s3_secret_access_key', 'minio123')

    metrics = robustness_check(object_storage_url, object_storage_username, object_storage_password,
                               data_bucket_name, result_bucket_name, model_id,
                               feature_testset_path=feature_testset_path,
                               label_testset_path=label_testset_path,
                               clip_values=clip_values,
                               nb_classes=nb_classes,
                               input_shape=input_shape,
                               model_class_file=model_class_file,
                               model_class_name=model_class_name,
                               LossFn=LossFn,
                               Optimizer=Optimizer,
                               epsilon=epsilon)

    if not os.path.exists(os.path.dirname(metric_path)):
        os.makedirs(os.path.dirname(metric_path))
    with open(metric_path, "w") as report:
        report.write(json.dumps(metrics))

    robust = "true"
    if metrics['model accuracy on adversarial samples'] < 0.2:
        robust = "false"
예제 #3
0
                        type=str,
                        help='Object storage bucket name')
    parser.add_argument('--s3_username',
                        type=str,
                        help='Object storage access key id')
    parser.add_argument('--s3_password',
                        type=str,
                        help='Object storage access key secret')
    parser.add_argument('--epsilon',
                        type=float,
                        help='Epsilon value for the FGSM attack')
    parser.add_argument('--model_id', type=str, help='Training model id')
    parser.add_argument('--metric_path',
                        type=str,
                        help='Path for robustness check output')
    args = parser.parse_args()

    s3_url = args.s3_url
    bucket_name = args.bucket_name
    s3_username = args.s3_username
    s3_password = args.s3_password
    epsilon = args.epsilon
    metric_path = args.metric_path
    model_id = args.model_id

    metrics = robustness_check(s3_url, bucket_name, s3_username, s3_password,
                               model_id, epsilon)

    with open(metric_path, "w") as report:
        report.write(json.dumps(metrics))