예제 #1
0
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            loss = criterion(outputs, labels)
            loss_sum += loss.data.cpu().item() * images.size(0)
            correct += (predicted == labels).sum().item()
            cnt += int(images.size()[0])
        print('Accuracy of the network on the 10000 test images: %f %%' %
              (100 * correct / total))
        print("loss=", loss_sum / float(cnt))

    elapsed_time = time.time() - inference_start
    print("Elapsed time for Prediction", elapsed_time)


if __name__ == "__main__":
    _, _, _, test_to_run = argparser_distributed()
    print("====== New Tests ======")
    print("Test To run:", test_to_run)

    net_state_name, config_name = get_net_config_name(test_to_run)
    print(f"net_state going to load: {net_state_name}")
    print(f"store_configs going to load: {config_name}")
    store_configs = np.load(config_name, allow_pickle="TRUE").item()

    def input_shift(data):
        """Apply ``data_shift`` to *data* using the first layer's stored config.

        The config is read from ``store_configs`` (the dict loaded from
        ``config_name`` in the enclosing ``__main__`` scope) under the key
        ``"conv1" + "ForwardX"``. The result of ``data_shift`` is returned
        unchanged.
        """
        head_layer = "conv1"
        shift_config = store_configs[head_layer + "ForwardX"]
        return data_shift(data, shift_config)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
예제 #2
0
        p = Process(target=run_secure_nn_client_with_random_data,
                    args=[secure_nn, correctness_func, master_address, master_port])
        p.start()
        processes.append(p)
        for p in processes:
            p.join()

    if party == Config.server_rank:
        run_secure_nn_server_with_random_data(secure_nn, correctness_func, master_address, master_port)
    if party == Config.client_rank:
        run_secure_nn_client_with_random_data(secure_nn, correctness_func, master_address, master_port)

    print(f"\nTest for {test_name}: End")

if __name__ == "__main__":
    input_sid, master_addr, master_port, test_to_run = argparser_distributed()
    sys.stdout = Logger()

    print("====== New Tests ======")
    print("Test To run:", test_to_run)

    num_repeat = 5

    for _ in range(num_repeat):
        if test_to_run in ["small", "all"]:
            marshal_secure_nn_parties(input_sid, master_addr, master_port, generate_small_nn(), correctness_small_nn)
        if test_to_run in ["relu", "all"]:
            marshal_secure_nn_parties(input_sid, master_addr, master_port, generate_relu_only_nn(), correctness_relu_only_nn)
        if test_to_run in ["maxpool", "all"]:
            marshal_secure_nn_parties(input_sid, master_addr, master_port, generate_maxpool2x2(), correctness_maxpool2x2)
        if test_to_run in ["conv2d", "all"]: