Exemple #1
0
def load_uci_update(role, ip, server, port, mpc_model_dir, mpc_model_filename, updated_model_dir):
    """
    Load, update and save uci MPC model.

    The model found under ``mpc_model_dir`` is loaded into the default main
    program, trained for ``network.MPC_UPDATE_EPOCH`` epochs on the shares
    stored in ``../mpc_data/`` and finally written to ``updated_model_dir``
    as a trainable model named ``mpc_updated_model``.

    Args:
        role: Party identifier; forwarded to ``pfl_mpc.init`` and the data
            loader, and used as the suffix of the loss file name.
        ip: Forwarded to ``pfl_mpc.init``.
        server: Forwarded to ``pfl_mpc.init``.
        port: Forwarded to ``pfl_mpc.init``.
        mpc_model_dir: Directory the MPC model is loaded from.
        mpc_model_filename: File name of the MPC model inside that directory.
        updated_model_dir: Directory the updated model is saved into.

    Side effects:
        Rewrites ``./tmp/uci_mpc_loss.part<role>`` (the ``./tmp`` directory
        must already exist) and prints progress to stdout.
    """
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Step 1. initialize MPC environment and load MPC model into default_main_program to update.
    pfl_mpc.init("aby3", role, ip, server, port)
    aby3.load_mpc_model(exe=exe,
                        mpc_model_dir=mpc_model_dir,
                        mpc_model_filename=mpc_model_filename)

    # Step 2. MPC update
    epoch_num = network.MPC_UPDATE_EPOCH
    batch_size = network.BATCH_SIZE
    mpc_data_dir = "../mpc_data/"
    feature_file = mpc_data_dir + "house_feature"
    feature_shape = (13,)
    label_file = mpc_data_dir + "house_label"
    label_shape = (1,)
    loss_file = "./tmp/uci_mpc_loss.part{}".format(role)
    # The loss file is opened in append mode below, so drop any stale copy
    # left over from a previous run.
    if os.path.exists(loss_file):
        os.remove(loss_file)
    updated_model_name = 'mpc_updated_model'
    feature_name = 'x'
    label_name = 'y'
    # fetch loss if needed
    loss = fluid.default_main_program().global_block().var('mean_0.tmp_0')
    loader = process_data.get_mpc_dataloader(feature_file, label_file, feature_shape, label_shape,
                                             feature_name, label_name, role, batch_size)
    start_time = time.time()
    for epoch_id in range(epoch_num):
        # enumerate() advances the step counter on every batch. Previously the
        # increment lived inside the ``if`` below, so the counter got stuck at
        # 1 and the loss was only recorded for the first batch of each epoch.
        for step, sample in enumerate(loader()):
            mpc_loss = exe.run(feed=sample, fetch_list=[loss.name])
            if step % 50 == 0:
                print('Epoch={}, Step={}, Loss={}'.format(epoch_id, step, mpc_loss))
                with open(loss_file, 'ab') as f:
                    # tobytes() is the supported spelling of the deprecated
                    # tostring(); both return the same raw bytes.
                    f.write(np.array(mpc_loss).tobytes())
    end_time = time.time()
    print('Mpc Updating of Epoch={} Batch_size={}, cost time in seconds:{}'
          .format(epoch_num, batch_size, (end_time - start_time)))

    # Step 3. save updated MPC model as a trainable model.
    aby3.save_trainable_model(exe=exe,
                              model_dir=updated_model_dir,
                              model_filename=updated_model_name)
    print('Successfully save mpc updated model into:{}'.format(updated_model_dir))
Exemple #2
0
def load_mpc_model_and_predict(role, ip, server, port, mpc_model_dir,
                               mpc_model_filename):
    """
    Predict based on MPC inference model, save prediction results into files.

    Args:
        role: Party identifier; forwarded to ``pfl_mpc.init`` and the test
            data loader, and used as the suffix of the prediction file name.
        ip: Forwarded to ``pfl_mpc.init``.
        server: Forwarded to ``pfl_mpc.init``.
        port: Forwarded to ``pfl_mpc.init``.
        mpc_model_dir: Directory the MPC inference model is loaded from.
        mpc_model_filename: File name of the model inside that directory.

    Side effects:
        Appends the raw bytes of the first batch's prediction to
        ``./tmp/uci_prediction.part<role>`` and prints the elapsed time.
    """
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Step 1. initialize MPC environment and load MPC model to predict
    pfl_mpc.init("aby3", role, ip, server, port)
    infer_prog, feed_names, fetch_targets = aby3.load_mpc_model(
        exe=exe,
        mpc_model_dir=mpc_model_dir,
        mpc_model_filename=mpc_model_filename,
        inference=True)

    # Step 2. MPC predict
    batch_size = network.BATCH_SIZE
    feature_file = "/tmp/house_feature"
    feature_shape = (13, )
    # NOTE(review): this file is opened in append mode and is not cleared
    # here, so results accumulate across runs unless the caller removes it.
    pred_file = "./tmp/uci_prediction.part{}".format(role)
    loader = process_data.get_mpc_test_dataloader(feature_file, feature_shape,
                                                  role, batch_size)
    start_time = time.time()
    for sample in loader():
        prediction = exe.run(program=infer_prog,
                             feed={feed_names[0]: np.array(sample)},
                             fetch_list=fetch_targets)
        # Step 3. save prediction results
        with open(pred_file, 'ab') as f:
            # tobytes() replaces the deprecated tostring(); same raw bytes.
            f.write(np.array(prediction).tobytes())
        # Only the first batch is predicted; the message below reports the
        # time for ``batch_size`` samples.
        break
    end_time = time.time()
    print('Mpc Predict with samples of {}, cost time in seconds:{}'.format(
        batch_size, (end_time - start_time)))
def infer(args):
    """
    Run MPC inference and dump the first fetched output of every batch.

    Args:
        args: Options object. The attributes read here are ``use_gpu``,
            ``model_dir``, ``test_epoch``, ``role``, ``server``, ``port``,
            ``mpc_data_dir``, ``batch_size``, ``batch_num``,
            ``watch_vec_size``, ``search_vec_size`` and ``other_feat_size``.

    Side effects:
        Rewrites ``<mpc_data_dir>user_vec.part<role>`` with the raw bytes of
        the first fetched result of each of the ``batch_num`` batches, and
        logs progress and total elapsed time.

    Raises:
        IndexError: if any of the three share readers yields fewer than
            ``args.batch_num`` batches.
    """
    logger.info('Start inferring...')
    begin = time.time()
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    cur_model_path = os.path.join(args.model_dir, 'mpc_model', 'epoch_' + str(args.test_epoch),
                                  'checkpoint', 'party_{}'.format(args.role))

    with fluid.scope_guard(fluid.Scope()):
        pfl_mpc.init('aby3', args.role, 'localhost', args.server, args.port)
        infer_program, feed_target_names, fetch_vars = aby3.load_mpc_model(exe=exe,
                                                                    mpc_model_dir=cur_model_path,
                                                                    mpc_model_filename='__model__',
                                                                    inference=True)
        # Paths are built by plain concatenation, so ``mpc_data_dir`` is
        # expected to end with a path separator.
        mpc_data_dir = args.mpc_data_dir
        user_vec_filepath = mpc_data_dir + 'user_vec'
        user_vec_part_filepath = user_vec_filepath + '.part{}'.format(args.role)

        sample_batch = args.batch_size

        # Load every batch of each input share up front; batches are then
        # paired by position in the inference loop below.
        watch_vec_reader = read_share(file=mpc_data_dir + 'watch_vec', shape=(sample_batch, args.watch_vec_size))
        watch_vecs = list(watch_vec_reader())
        search_vec_reader = read_share(file=mpc_data_dir + 'search_vec', shape=(sample_batch, args.search_vec_size))
        search_vecs = list(search_vec_reader())
        other_feat_reader = read_share(file=mpc_data_dir + 'other_feat', shape=(sample_batch, args.other_feat_size))
        other_feats = list(other_feat_reader())

        # The output is appended to batch by batch, so start from a clean
        # file. os.remove avoids building a shell command from a path that
        # comes from user-supplied arguments.
        if os.path.exists(user_vec_part_filepath):
            os.remove(user_vec_part_filepath)

        for i in range(args.batch_num):
            l3 = exe.run(infer_program,
                         feed={
                               'watch_vec': watch_vecs[i],
                               'search_vec': search_vecs[i],
                               'other_feat': other_feats[i],
                         },
                         return_numpy=True,
                         fetch_list=fetch_vars)

            # Binary append mode: the payload is raw bytes, which a text-mode
            # handle rejects on Python 3.
            with open(user_vec_part_filepath, 'ab') as f:
                f.write(np.array(l3[0]).tobytes())

    end = time.time()
    logger.info('MPC inferring, cost_time: {:.5f}s'.format(end - begin))
    logger.info('End inferring.')
Exemple #4
0
def infer(test_loader, role, exe, BATCH_SIZE, mpc_model_dir,
          mpc_model_filename):
    """
    load mpc model and infer

    Args:
        test_loader: Callable returning an iterable of feed samples that are
            passed as-is to ``exe.run``.
        role: Party identifier; used as the suffix of the prediction share
            file. Only the party with ``role == 0`` decrypts and evaluates.
        exe: Executor used to load the model and run inference.
        BATCH_SIZE: Used to build the shape ``(BATCH_SIZE,)`` handed to
            ``process_data.decrypt_data_to_file``.
        mpc_model_dir: Directory the MPC inference model is loaded from.
        mpc_model_filename: File name of the model inside that directory.

    Side effects:
        Creates ``./mpc_infer_data/`` if missing, rewrites
        ``./mpc_infer_data/prediction.part<role>`` and, for role 0,
        ``./mpc_infer_data/label_mpc``; logs progress and timings.
    """
    # Load mpc model
    logger.info('Load model from {}'.format(mpc_model_dir))
    infer_program, feed_targets, fetch_targets = aby3.load_mpc_model(
        exe=exe,
        mpc_model_dir=mpc_model_dir,
        mpc_model_filename=mpc_model_filename,
        inference=True)

    # Infer
    logger.info('******************************************')
    logger.info('Start Inferring...')
    mpc_infer_data_dir = "./mpc_infer_data/"
    if not os.path.exists(mpc_infer_data_dir):
        try:
            os.mkdir(mpc_infer_data_dir)
        except OSError as e:
            # Another process may have created the directory between the
            # existence check and mkdir; only re-raise other failures.
            if e.errno != errno.EEXIST:
                raise

    # The prediction file is appended to per batch, so remove any stale copy.
    prediction_file = mpc_infer_data_dir + "prediction.part{}".format(role)
    if os.path.exists(prediction_file):
        os.remove(prediction_file)

    start_time = time.time()
    for sample in test_loader():
        prediction = exe.run(program=infer_program,
                             feed=sample,
                             fetch_list=fetch_targets)
        with open(prediction_file, 'ab') as f:
            # tobytes() replaces the deprecated tostring(); same raw bytes.
            f.write(np.array(prediction).tobytes())
    end_time = time.time()
    logger.info('End Inferring...cost time: {}'.format(end_time - start_time))

    logger.info('Start Evaluate Accuracy...')
    cypher_file = mpc_infer_data_dir + "prediction"
    decrypt_file = mpc_infer_data_dir + 'label_mpc'
    label_file = mpc_infer_data_dir + 'label_criteo'
    # NOTE(review): presumably gives the other parties time to finish writing
    # their prediction shares before decryption - confirm.
    time.sleep(0.1)
    if role == 0:
        if os.path.exists(decrypt_file):
            os.remove(decrypt_file)
        process_data.decrypt_data_to_file(cypher_file, (BATCH_SIZE, ),
                                          decrypt_file)
        evaluate.evaluate_accuracy(label_file, decrypt_file)
        evaluate.evaluate_auc(label_file, decrypt_file)

    # This duration is measured from the start of inference, so it includes
    # the inference loop as well as the evaluation.
    end_time = time.time()
    logger.info('End Evaluate Accuracy...cost time: {}'.format(end_time -
                                                               start_time))
    logger.info('******************************************')