Example #1
0
def main(_):
    """Run SLALOM-transformed inference on ImageNet validation batches.

    Loads ``args.model_name`` via ``get_model``, rewrites it with ``transform``
    (quantized, with the SLALOM integrity/privacy options taken from ``args``),
    serializes it with ``model_to_json`` and loads it into a single-enclave
    ``SGXDNNUtils``. Then, for each of ``args.max_num_batches`` batches, it
    quantizes the images, optionally precomputes blinding factors, blinds the
    input with ``sgxutils.slalom_blind_input``, runs the TF graph and records
    timing/accuracy with ``Results``. A Chrome trace of each run is written
    (overwriting the previous one) to ``timeline_<model_name>_gpu.json``.

    Args:
        _: Ignored positional argument (presumably the argv passed by
            ``tf.app.run`` -- TODO confirm).
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    sgxutils = None
    try:
        with tf.Graph().as_default():
            # tf.set_random_seed seeds the *current default graph*; it must run
            # after this graph is installed as the default, otherwise the graph
            # that holds every op below is left without a graph-level seed.
            tf.set_random_seed(1234)

            num_batches = args.max_num_batches
            batch_size = args.batch_size

            device = '/gpu:0'
            config = tf.ConfigProto(log_device_placement=False)
            config.allow_soft_placement = True
            config.gpu_options.per_process_gpu_memory_fraction = 0.90
            config.gpu_options.allow_growth = True

            quantize = True
            slalom = not args.no_slalom
            blinded = args.blinding
            integrity = args.integrity
            simulate = args.simulate
            # Blinding factors are precomputed (through TF FIFO queues) only
            # when blinding is really performed rather than simulated.
            precompute_blinding = blinded and not simulate
            timeline_path = 'timeline_{}_{}.json'.format(args.model_name,
                                                         device[1:4])

            with tf.Session(config=config) as sess:
                with tf.device(device):
                    model, model_info = get_model(args.model_name,
                                                  batch_size,
                                                  include_top=True,
                                                  double_prec=False)

                dataset_images, labels = imagenet.load_validation(
                    args.input_dir,
                    batch_size,
                    preprocess=model_info['preprocess'])

                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                try:
                    sgxutils = SGXDNNUtils(args.use_sgx, num_enclaves=1)

                    num_linear_layers = len(get_all_linear_layers(model))
                    if precompute_blinding:
                        # One queue per linear layer, each with room for
                        # num_batches + 1 entries.
                        queues = [
                            tf.FIFOQueue(capacity=num_batches + 1,
                                         dtypes=[tf.float32])
                            for _ in range(num_linear_layers)
                        ]
                    else:
                        queues = None

                    # The linear input/output ops returned by transform are
                    # not needed in this script.
                    model, _, _ = transform(
                        model,
                        log=False,
                        quantize=quantize,
                        verif_preproc=True,
                        slalom=slalom,
                        slalom_integrity=integrity,
                        slalom_privacy=blinded,
                        bits_w=model_info['bits_w'],
                        bits_x=model_info['bits_x'],
                        sgxutils=sgxutils,
                        queues=queues)

                    dtype = np.float32
                    model_json, weights = model_to_json(
                        sess,
                        model,
                        dtype=dtype,
                        verif_preproc=True,
                        slalom_privacy=blinded,
                        bits_w=model_info['bits_w'],
                        bits_x=model_info['bits_x'])
                    sgxutils.load_model(model_json,
                                        weights,
                                        dtype=dtype,
                                        verify=True,
                                        verify_preproc=True)

                    num_classes = np.prod(
                        model.output.get_shape().as_list()[1:])
                    print("num_classes: {}".format(num_classes))

                    # Accuracy is only reported for a 1000-class output.
                    res = Results(acc=(num_classes == 1000))
                    run_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)
                    run_metadata = tf.RunMetadata()

                    sgxutils.slalom_init(integrity, precompute_blinding,
                                         batch_size)
                    if precompute_blinding:
                        (in_ph, zs, out_ph, queue_ops, temps,
                         out_funcs) = build_blinding_ops(model, queues,
                                                         batch_size)

                    for _ in range(num_batches):
                        images, true_labels = sess.run(
                            [dataset_images, labels])

                        if quantize:
                            # Fixed-point encode the inputs by scaling with
                            # 2**bits_x and rounding.
                            images = np.round(
                                2**model_info['bits_x'] * images).astype(
                                    np.float32)
                            print("input images: {}".format(
                                np.sum(np.abs(images))))

                        if precompute_blinding:
                            prepare_blinding_factors(
                                sess,
                                model,
                                sgxutils,
                                in_ph,
                                zs,
                                out_ph,
                                queue_ops,
                                batch_size,
                                num_batches=1,
                            )

                        images = sgxutils.slalom_blind_input(images)
                        print("blinded images: {}".format(
                            (np.min(images), np.max(images),
                             np.sum(np.abs(images.astype(np.float64))))))
                        print(images.reshape(-1)[:3], images.reshape(-1)[-3:])

                        res.start_timer()
                        preds = sess.run(model.outputs[0],
                                         feed_dict={
                                             model.inputs[0]: images,
                                             backend.learning_phase(): 0
                                         },
                                         options=run_options,
                                         run_metadata=run_metadata)
                        preds = np.reshape(preds, (batch_size, -1))
                        res.end_timer(size=len(images))
                        res.record_acc(preds, true_labels)
                        res.print_results()

                        # Overwritten every batch: the file keeps the trace of
                        # the most recent run.
                        tl = timeline.Timeline(run_metadata.step_stats)
                        with open(timeline_path, 'w') as f:
                            f.write(tl.generate_chrome_trace_format())

                        sys.stdout.flush()
                finally:
                    # Always stop the input-pipeline threads, even on error.
                    coord.request_stop()
                    coord.join(threads)
    finally:
        # Release the enclave even if model setup or inference failed.
        if sgxutils is not None:
            sgxutils.destroy()
Example #2
0
def main(_):
    """Benchmark a model under plain TF (GPU/CPU) or the partitioned SGX setup.

    ``args.mode`` selects the execution path:

    * ``'tf-gpu'`` / ``'tf-cpu'``: run the (optionally quantized) model end to
      end in TensorFlow, record timing/accuracy with ``Results`` and write a
      Chrome trace to ``timeline.json``.
    * ``'sgxdnn'``: with ``args.verify``, run the linear layers on the GPU and
      check them with ``sgxutils.predict_and_verify``; without it, run the
      experimental partitioned pipeline whose second half is evaluated by
      ``get_final_layer_model_output``.

    Args:
        _: Ignored positional argument (presumably the argv passed by
            ``tf.app.run`` -- TODO confirm).

    Raises:
        NotImplementedError: in ``'sgxdnn'`` mode with ``args.verify`` while no
            ``SGXDNNUtils`` enclave is created (as in this version).
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    sgxutils = None
    try:
        with tf.Graph().as_default():
            # tf.set_random_seed seeds the *current default graph*, so it has
            # to run after this graph has been installed as the default.
            tf.set_random_seed(1234)

            num_batches = args.max_num_batches

            if args.mode == 'tf-gpu':
                assert not args.use_sgx

                device = '/gpu:0'
                config = tf.ConfigProto(log_device_placement=False)
                config.allow_soft_placement = True
                config.gpu_options.per_process_gpu_memory_fraction = 0.90
                config.gpu_options.allow_growth = True

            elif args.mode == 'tf-cpu':
                assert not args.verify and not args.use_sgx

                # The session hides all GPUs ('GPU': 0) and does not enable
                # soft placement, so an explicit '/gpu:0' placement cannot be
                # satisfied; ops have to be pinned to the CPU.
                device = '/cpu:0'
                config = tf.ConfigProto(log_device_placement=False,
                                        device_count={
                                            'CPU': 1,
                                            'GPU': 0
                                        })
                config.intra_op_parallelism_threads = 1
                config.inter_op_parallelism_threads = 1

            else:
                assert args.mode == 'sgxdnn'

                device = '/gpu:0'
                config = tf.ConfigProto(log_device_placement=False)
                config.allow_soft_placement = True
                config.gpu_options.per_process_gpu_memory_fraction = 0.9
                config.gpu_options.allow_growth = True

            with tf.Session(config=config) as sess:
                # Use the device chosen for the selected mode (previously this
                # was unconditionally reset to '/gpu:0').
                with tf.device(device):
                    model, model_info = get_model(args.model_name,
                                                  args.batch_size,
                                                  include_top=not args.no_top)
                    # Evaluates the last layer's output for a given input; used
                    # as "partition 2" in sgxdnn mode.
                    get_final_layer_model_output = backend.function(
                        [model.input], [model.layers[-1].output])

                dataset_images, labels = imagenet.load_validation(
                    args.input_dir,
                    args.batch_size,
                    preprocess=model_info['preprocess'],
                    num_preprocessing_threads=1)

                model, _, linear_ops_out = transform(
                    model,
                    log=False,
                    quantize=args.verify,
                    verif_preproc=args.preproc,
                    bits_w=model_info['bits_w'],
                    bits_x=model_info['bits_x'])

                if args.mode == 'sgxdnn':
                    # NOTE: enclave setup (SGXDNNUtils, a partitioned
                    # model_to_json and load_model) is disabled in this variant,
                    # so `sgxutils` stays None.
                    dtype = np.float32 if not args.verify else DTYPE_VERIFY
                    if args.verify and sgxutils is None:
                        # Fail fast instead of calling predict_and_verify on
                        # None later in the batch loop.
                        raise NotImplementedError(
                            "--verify in 'sgxdnn' mode requires an SGXDNNUtils "
                            "enclave, which is not created in this version")

                num_classes = np.prod(model.output.get_shape().as_list()[1:])
                print("num_classes: {}".format(num_classes))
                # Presumably the (H, W, C) shape of the feature map at the
                # partition point -- it is only printed here.
                num_inter_feats = [28, 28, 256]
                print("number of intermediate features: {}".format(
                    num_inter_feats))

                # Accuracy is only reported for a 1000-class output.
                res = Results(acc=(num_classes == 1000))
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                run_options = tf.RunOptions(
                    trace_level=tf.RunOptions.FULL_TRACE)
                run_metadata = tf.RunMetadata()

                try:
                    for _ in range(num_batches):
                        images, true_labels = sess.run(
                            [dataset_images, labels])

                        if args.verify:
                            images = np.round(2**model_info['bits_x'] * images)
                        print("input images: {}".format(
                            np.sum(np.abs(images))))

                        if args.mode in ('tf-gpu', 'tf-cpu'):
                            res.start_timer()
                            preds = sess.run(model.outputs[0],
                                             feed_dict={
                                                 model.inputs[0]: images,
                                                 backend.learning_phase(): 0
                                             },
                                             options=run_options,
                                             run_metadata=run_metadata)

                            print(np.sum(np.abs(images)),
                                  np.sum(np.abs(preds)))
                            preds = np.reshape(preds,
                                               (args.batch_size, num_classes))
                            res.end_timer(size=len(images))
                            res.record_acc(preds, true_labels)
                            res.print_results()

                            tl = timeline.Timeline(run_metadata.step_stats)
                            with open('timeline.json', 'w') as f:
                                f.write(tl.generate_chrome_trace_format())

                        else:
                            # NOTE: timing/accuracy are not reported in sgxdnn
                            # mode (end_timer/record_acc are never called).
                            res.start_timer()

                            if args.verify:
                                t1 = timer()
                                linear_outputs = sess.run(
                                    linear_ops_out,
                                    feed_dict={
                                        model.inputs[0]: images,
                                        backend.learning_phase(): 0
                                    },
                                    options=run_options,
                                    run_metadata=run_metadata)
                                t2 = timer()
                                print("GPU compute time: {:.4f}".format(
                                    (t2 - t1) / (1.0 * args.batch_size)))

                                if not args.verify_batched:
                                    # Verify one image at a time, passing that
                                    # image's row of every linear output.
                                    start = timer()
                                    linear_outputs_batched = [
                                        x.reshape((args.batch_size, -1))
                                        for x in linear_outputs
                                    ]

                                    preds = []
                                    for j in range(args.batch_size):
                                        t1 = timer()
                                        aux_data = [
                                            x[j]
                                            for x in linear_outputs_batched
                                        ]
                                        pred = sgxutils.predict_and_verify(
                                            images[j:j + 1],
                                            aux_data,
                                            num_classes=num_classes,
                                            dtype=dtype)
                                        t2 = timer()
                                        print("verify time: {:.4f}".format(
                                            (t2 - t1)))
                                        preds.append(pred)
                                    preds = np.vstack(preds)
                                    end = timer()
                                    print("avg verify time: {:.4f}".format(
                                        (end - start) /
                                        (1.0 * args.batch_size)))
                                else:
                                    preds = sgxutils.predict_and_verify(
                                        images,
                                        linear_outputs,
                                        num_classes=num_classes,
                                        dtype=dtype)

                            else:
                                # NOTE: partition 1 is meant to run inside the
                                # enclave (sgxutils.predict with
                                # is_intermediate=True). That call is disabled,
                                # so an all-zero placeholder of shape
                                # (4, 224, 224, 3) is fed to partition 2 for
                                # every image instead.
                                intermediate_feats = np.zeros((4, 224, 224, 3))
                                preds = []
                                for _ in range(args.batch_size):
                                    print("intermediate features shape: {}".
                                          format(intermediate_feats.shape))
                                    start_inter_time = time.time()
                                    # Partition 2 runs outside the enclave.
                                    with tf.device(device):
                                        pred = get_final_layer_model_output(
                                            [intermediate_feats])[0]
                                    print("pred shape: {}".format(pred.shape))
                                    print("GPU compute time: {}".format(
                                        time.time() - start_inter_time))

                                    preds.append(pred)
                                preds = np.vstack(preds)

                        sys.stdout.flush()
                finally:
                    # Always stop the input-pipeline threads, even on error.
                    coord.request_stop()
                    coord.join(threads)
    finally:
        if sgxutils is not None:
            sgxutils.destroy()
Example #3
0
def main(_):
    """Train the CIFAR-10 test model through the SGX enclave, then with Keras.

    Builds ``get_test_model(args.batch_size)``, rewrites it with ``transform``
    and, in ``'sgxdnn'`` mode, loads it into ``SGXDNNUtils``. It then iterates
    ``args.epoch`` epochs over the CIFAR-10 training set, calling
    ``sgxutils.train`` per batch when ``args.train`` is set. Finally it runs
    one epoch of Keras ``fit`` (batch size 32) on the model object that was
    passed to ``transform``.

    Args:
        _: Ignored positional argument (presumably the argv passed by
            ``tf.app.run`` -- TODO confirm).

    Raises:
        ValueError: if ``args.train`` is set outside ``'sgxdnn'`` mode, where
            no enclave is created.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    if args.train and args.mode != 'sgxdnn':
        # sgxutils.train needs the enclave, which only sgxdnn mode creates.
        raise ValueError("--train requires --mode sgxdnn")

    sgxutils = None
    try:
        with tf.Graph().as_default():
            # tf.set_random_seed seeds the *current default graph*, so it has
            # to run after this graph has been installed as the default.
            tf.set_random_seed(1234)

            if args.mode == 'tf-gpu':
                assert not args.use_sgx

                device = '/gpu:0'
                config = tf.ConfigProto(log_device_placement=False)
                config.allow_soft_placement = True
                config.gpu_options.per_process_gpu_memory_fraction = 0.90
                config.gpu_options.allow_growth = True

            elif args.mode == 'tf-cpu':
                assert not args.verify and not args.use_sgx

                device = '/cpu:0'
                config = tf.ConfigProto(log_device_placement=False,
                                        device_count={
                                            'CPU': 1,
                                            'GPU': 0
                                        })
                config.intra_op_parallelism_threads = 1
                config.inter_op_parallelism_threads = 1

            else:
                assert args.mode == 'sgxdnn'

                device = '/gpu:0'
                config = tf.ConfigProto(log_device_placement=False)
                config.allow_soft_placement = True
                config.gpu_options.per_process_gpu_memory_fraction = 0.9
                config.gpu_options.allow_growth = True

            with tf.Session(config=config) as sess:
                with tf.device(device):
                    model, model_info = get_test_model(args.batch_size)
                # Keep a handle on the model object passed to `transform`; it
                # is trained with Keras `fit` at the end. (Whether `transform`
                # mutates it in place is not visible here -- TODO confirm.)
                model_copy = model
                model, _, _ = transform(
                    model,
                    log=False,
                    quantize=args.verify,
                    verif_preproc=args.preproc,
                    bits_w=model_info['bits_w'],
                    bits_x=model_info['bits_x'])

                if args.mode == 'sgxdnn':
                    sgxutils = SGXDNNUtils(args.use_sgx)

                    dtype = np.float32 if not args.verify else DTYPE_VERIFY
                    model_json, weights = model_to_json(
                        sess,
                        model,
                        args.preproc,
                        dtype=dtype,
                        bits_w=model_info['bits_w'],
                        bits_x=model_info['bits_x'])
                    sgxutils.load_model(model_json,
                                        weights,
                                        dtype=dtype,
                                        verify=args.verify,
                                        verify_preproc=args.preproc)

                num_classes = np.prod(model.output.get_shape().as_list()[1:])
                print("num_classes: {}".format(num_classes))

                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                try:
                    # Only the training split is used; pixels are scaled to
                    # [0, 1] and labels one-hot encoded.
                    (X_train, y_train), _ = cifar10.load_data()
                    y_train = y_train.reshape(y_train.shape[0])
                    X_train = X_train.astype('float32')
                    X_train /= 255
                    y_train = to_categorical(y_train, num_classes)

                    num_examples = X_train.shape[0]
                    num_batches = num_examples // args.batch_size
                    print('training batch number :{}'.format(num_batches))

                    lr = 0.001
                    for k in range(args.epoch):
                        # Decay the learning rate by 5% before epochs 10, 20,
                        # 30, ... The old test `(k + 1) % 10` was truthy for
                        # every epoch *except* those, inverting the schedule.
                        if (k + 1) % 10 == 0:
                            lr *= 0.95
                        print('Epoch {}/{}'.format(k + 1, args.epoch))
                        for i in range(num_batches):
                            done_number = int(30 * (i + 1) / num_batches)
                            print("\r{}/{} [{}>{}] {:.2f}% ".format(
                                (i + 1) * args.batch_size, num_examples,
                                '=' * done_number, '.' * (30 - done_number),
                                100 * (i + 1) / num_batches),
                                  end='')
                            lo = i * args.batch_size
                            hi = lo + args.batch_size
                            batch_images = X_train[lo:hi]
                            batch_labels = y_train[lo:hi]
                            if args.train:
                                loss_batch, acc_batch = sgxutils.train(
                                    batch_images,
                                    batch_labels,
                                    num_classes=num_classes,
                                    learn_rate=lr)
                                print(' - loss :{:.4f} - acc :{:.4f}'.format(
                                    loss_batch, acc_batch),
                                      end='')
                            # Flush so the '\r' progress line updates live.
                            sys.stdout.flush()
                        # Terminate the progress line before the next header.
                        print()

                    # Plain Keras training of the pre-transform model object.
                    model_copy.fit(X_train, y_train, batch_size=32, epochs=1)
                finally:
                    # Always stop the input-pipeline threads, even on error.
                    coord.request_stop()
                    coord.join(threads)
    finally:
        # Release the enclave even if setup or training failed.
        if sgxutils is not None:
            sgxutils.destroy()