Code Example #1
def validate_mpc(dataloader: DataLoader, model: crypten.nn.Module,
                 loss: crypten.nn.Module):
    model.eval()
    outs = []
    true_ys = []
    total_loss = None
    count = len(dataloader)
    for xs, ys in tqdm(dataloader, file=sys.stdout):
        out = model(xs)
        loss_val = loss(out, ys)

        outs.append(out)
        true_ys.append(ys)

        if total_loss is None:
            total_loss = loss_val.detach()
        else:
            total_loss += loss_val.detach()

    total_loss = total_loss.get_plain_text().item()

    all_out = crypten.cat(outs, dim=0)
    all_prob = all_out.sigmoid()
    all_prob = all_prob.get_plain_text()
    pred_ys = torch.where(all_prob > 0.5, 1, 0).tolist()
    pred_probs = all_prob.tolist()

    true_ys = crypten.cat(true_ys, dim=0)
    true_ys = true_ys.get_plain_text().tolist()

    return total_loss / count, precision_score(true_ys, pred_ys), recall_score(true_ys, pred_ys), \
           roc_auc_score(true_ys, pred_probs)
Code Example #2
def randn(*sizes, device=None):
    """
    Returns a tensor with normally distributed elements. Samples are
    generated using the Box-Muller transform with optimizations for
    numerical precision and MPC efficiency.
    """
    u = crypten.rand(*sizes, device=device).flatten()
    odd_numel = u.numel() % 2 == 1
    if odd_numel:
        u = crypten.cat([u, crypten.rand((1, ), device=device)])

    n = u.numel() // 2
    u1 = u[:n]
    u2 = u[n:]

    # Radius = sqrt(- 2 * log(u1))
    r2 = -2 * u1.log(input_in_01=True)
    r = r2.sqrt()

    # Theta = cos(2 * pi * u2) or sin(2 * pi * u2)
    cos, sin = u2.sub(0.5).mul(6.28318531).cossin()

    # Generate two independent normal random variables from the radius and angle
    x = r.mul(sin)
    y = r.mul(cos)
    z = crypten.cat([x, y])

    if odd_numel:
        z = z[1:]

    return z.view(*sizes)
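
A quick sanity check of the sampler above (a minimal sketch, assuming a running CrypTen environment and the randn defined above; the sample count is only illustrative): decrypt a batch of Box-Muller samples and confirm their moments are close to those of a standard normal.

import torch
import crypten

crypten.init()
z = randn(10000).get_plain_text()
# mean and std should be roughly 0 and 1, up to fixed-point encoding error
print(z.mean().item(), z.std().item())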
Code Example #3
File: test_autograd.py  Project: xuleimath/CrypTen
        def test_case9(input, encr_input):
            intermediate1 = torch.cat([input, input])
            intermediate2 = intermediate1.mean(0, keepdim=True)
            output = torch.cat([intermediate2, intermediate1], dim=0).sum()

            encr_intermediate1 = crypten.cat([encr_input, encr_input])
            encr_intermediate2 = encr_intermediate1.mean(0, keepdim=True)
            encr_output = crypten.cat([encr_intermediate2, encr_intermediate1]).sum()

            return output, encr_output
Code Example #4
        def test_case8(input, encr_input):
            intermediate1 = input.add(3.0)
            intermediate2 = torch.cat([input, intermediate1])
            intermediate3 = intermediate2.pow(2.0)
            output = torch.cat([input, intermediate2,
                                intermediate3]).add(-1).sum()

            encr_intermediate1 = encr_input.add(3.0)
            encr_intermediate2 = crypten.cat([encr_input, encr_intermediate1])
            encr_intermediate3 = encr_intermediate2.pow(2.0)
            encr_output = (crypten.cat(
                [encr_input, encr_intermediate2,
                 encr_intermediate3]).add(-1).sum())

            return output, encr_output
Code Example #5
    def extend_row(tensor, dim, start_ind, end_ind):
        if reduction == "mean":
            extended_value = tensor.index_select(dim, torch.arange(start_ind, end_ind))
            extended_value = extended_value.mean(dim, keepdim=True)
        elif reduction == "max":
            extended_value = tensor.index_select(dim, torch.tensor(start_ind))
        else:
            raise ValueError(f"Invalid reduction {reduction} for adaptive pooling.")

        if start_ind == 0:
            return crypten.cat([extended_value, tensor], dim=dim)

        x = tensor.index_select(dim, torch.arange(start_ind))
        y = tensor.index_select(dim, torch.arange(start_ind, tensor.size(dim)))
        return crypten.cat([x, extended_value, y], dim=dim)
Code Example #6
File: mpc.py  Project: fionnoif/CrypTen
    def polynomial(self, coeffs, func="mul"):
        """Computes a polynomial function on a tensor with given coefficients,
        `coeffs`, that can be a list of values or a 1-D tensor.

        Coefficients should be ordered from the first-order (linear) term
        up to the highest-order term; a constant term is not included.
        """
        # Coefficient input type-checking
        if isinstance(coeffs, list):
            coeffs = torch.tensor(coeffs)
        assert torch.is_tensor(coeffs) or crypten.is_encrypted_tensor(coeffs), \
            "Polynomial coefficients must be a list or tensor"
        assert coeffs.dim() == 1, "Polynomial coefficients must be a 1-D tensor"

        # Handle linear case
        if coeffs.size(0) == 1:
            return self.mul(coeffs)

        # Compute terms of polynomial using exponentially growing tree
        terms = crypten.mpc.stack([self, self.square()])
        while terms.size(0) < coeffs.size(0):
            highest_term = terms[-1:].expand(terms.size())
            new_terms = getattr(terms, func)(highest_term)
            terms = crypten.cat([terms, new_terms])

        # Truncate terms and resize the coefficients for broadcasting
        terms = terms[:coeffs.size(0)]
        for _ in range(terms.dim() - 1):
            coeffs = coeffs.unsqueeze(1)

        # Multiply terms by coefficients and sum
        return terms.mul(coeffs).sum(0)
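
A minimal usage sketch for the method above (assuming a running CrypTen environment and that `polynomial` is exposed on encrypted tensors as in this file): evaluate 2*x + 3*x^2, with the linear coefficient listed first and no constant term.

import torch
import crypten

crypten.init()
x_enc = crypten.cryptensor(torch.tensor([0.5, 1.0, 2.0]))
y_enc = x_enc.polynomial([2.0, 3.0])  # 2*x + 3*x^2
print(y_enc.get_plain_text())  # roughly [1.75, 5.0, 16.0]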
Code Example #7
def _compute_pairwise_comparisons_for_steps(input_tensor, dim, steps):
    """
    Helper function that does pairwise comparisons by splitting input
    tensor for `steps` number of steps along dimension `dim`.
    """
    enc_tensor_reduced = input_tensor.clone()
    for _ in range(steps):
        m = enc_tensor_reduced.size(dim)
        x, y, remainder = enc_tensor_reduced.split([m // 2, m // 2, m % 2],
                                                   dim=dim)
        pairwise_max = crypten.where(x >= y, x, y)
        enc_tensor_reduced = crypten.cat([pairwise_max, remainder], dim=dim)
    return enc_tensor_reduced
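
A minimal usage sketch for the helper above (assuming a running CrypTen environment): ceil(log2(n)) steps reduce an n-element dimension to a single encrypted maximum.

import math
import torch
import crypten

crypten.init()
x_enc = crypten.cryptensor(torch.tensor([3.0, 7.0, 1.0, 5.0, 2.0]))
steps = int(math.ceil(math.log2(x_enc.size(0))))
max_enc = _compute_pairwise_comparisons_for_steps(x_enc, dim=0, steps=steps)
print(max_enc.get_plain_text())  # a single element holding the maximum, 7.0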
Code Example #8
 def test_case5(input, encr_input):
     intermediate1 = input.mul(3.0)  # PyTorch
     intermediate2 = input.add(2.0).pow(2.0)
     intermediate3 = input.pow(2.0)
     output = (torch.cat([intermediate1, intermediate2,
                          intermediate3]).mul(0.5).sum())
     encr_intermediate1 = encr_input.mul(3.0)  # CrypTen
     encr_intermediate2 = encr_input.add(2.0).square()
     encr_intermediate3 = encr_input.pow(2.0)
     encr_output = (crypten.cat(
         [encr_intermediate1, encr_intermediate2,
          encr_intermediate3]).mul(0.5).sum())
     return output, encr_output
Code Example #9
File: mpc_autograd_cnn.py  Project: yangzpag/CrypTen
def run_mpc_autograd_cnn(
    context_manager=None,
    num_epochs=3,
    learning_rate=0.001,
    batch_size=5,
    print_freq=5,
    num_samples=100,
):
    """
    Args:
        context_manager: used for setting proxy settings during download.
    """
    crypten.init()

    data_alice, data_bob, train_labels = preprocess_mnist(context_manager)
    rank = comm.get().get_rank()

    # assumes at least two parties exist
    # broadcast dummy data with same shape to remaining parties
    if rank == 0:
        x_alice = data_alice
    else:
        x_alice = torch.empty(data_alice.size())

    if rank == 1:
        x_bob = data_bob
    else:
        x_bob = torch.empty(data_bob.size())

    # encrypt
    x_alice_enc = crypten.cryptensor(x_alice, src=0)
    x_bob_enc = crypten.cryptensor(x_bob, src=1)

    # combine feature sets
    x_combined_enc = crypten.cat([x_alice_enc, x_bob_enc], dim=2)
    x_combined_enc = x_combined_enc.unsqueeze(1)

    # reduce training set to num_samples
    x_reduced = x_combined_enc[:num_samples]
    y_reduced = train_labels[:num_samples]

    # encrypt plaintext model
    model_plaintext = CNN()
    dummy_input = torch.empty((1, 1, 28, 28))
    model = crypten.nn.from_pytorch(model_plaintext, dummy_input)
    model.train()
    model.encrypt()

    # encrypted training
    train_encrypted(x_reduced, y_reduced, model, num_epochs, learning_rate,
                    batch_size, print_freq)
Code Example #10
File: test_autograd.py  Project: xuleimath/CrypTen
 def test_case6(input, encr_input):
     idx1 = torch.tensor([[0, 2, 4, 3, 8]], dtype=torch.long)
     idx2 = torch.tensor([[5, 1, 3, 5, 2]], dtype=torch.long)
     idx3 = torch.tensor([[2, 3, 1]], dtype=torch.long)
     intermediate1 = input.gather(0, idx1).gather(1, idx3).pow(2.0)  # PyTorch
     intermediate2 = input.gather(0, idx2).gather(1, idx3).add(-2.0)
     output = torch.cat([intermediate1, intermediate2]).mul(0.5).sum()
     encr_intermediate1 = (
         encr_input.gather(0, idx1).gather(1, idx3).square()
     )  # CrypTen
     encr_intermediate2 = encr_input.gather(0, idx2).gather(1, idx3).add(-2.0)
     encr_output = (
         crypten.cat([encr_intermediate1, encr_intermediate2], dim=0)
         .mul(0.5)
         .sum()
     )
     return output, encr_output
Code Example #11
def load_encrypt_tensor(filename: str) -> crypten.CrypTensor:
    local_tensor = load_local_tensor(filename)
    rank = comm.get().get_rank()
    count = local_tensor.shape[0]

    encrypt_tensors = []
    for i, (name, feature_size) in enumerate(zip(names, feature_sizes)):
        if rank == i:
            assert local_tensor.shape[1] == feature_size, \
                f"{name} feature size should be {feature_size}, but got {local_tensor.shape[1]}"
            tensor = crypten.cryptensor(local_tensor, src=i)
        else:
            dummy_tensor = torch.zeros((count, feature_size),
                                       dtype=torch.float32)
            tensor = crypten.cryptensor(dummy_tensor, src=i)
        encrypt_tensors.append(tensor)

    res = crypten.cat(encrypt_tensors, dim=1)
    return res
Code Example #12
def _max_helper_double_log_recursive(enc_tensor, dim):
    """Recursive subroutine for computing max via double log reduction algorithm"""
    n = enc_tensor.size(dim)
    # compute integral sqrt(n) and the integer number of sqrt(n) size
    # vectors that can be extracted from n
    sqrt_n = int(math.sqrt(n))
    count_sqrt_n = n // sqrt_n
    # base case for recursion: no further splits along dimension dim
    if n == 1:
        return enc_tensor
    else:
        # split into tensors that can be broken into vectors of size sqrt(n)
        # and the remainder of the tensor
        size_arr = [sqrt_n * count_sqrt_n, n % sqrt_n]
        split_enc_tensor, remainder = enc_tensor.split(size_arr, dim=dim)

        # reshape such that dim holds sqrt_n and dim+1 holds count_sqrt_n
        updated_enc_tensor_size = [
            sqrt_n, enc_tensor.size(dim + 1) * count_sqrt_n
        ]
        size_arr = [enc_tensor.size(i) for i in range(enc_tensor.dim())]
        size_arr[dim], size_arr[dim + 1] = updated_enc_tensor_size
        split_enc_tensor = split_enc_tensor.reshape(size_arr)

        # recursive call on reshaped tensor
        split_enc_max = _max_helper_double_log_recursive(split_enc_tensor, dim)

        # reshape the result to have the (dim+1)th dimension as before
        # and concatenate the previously computed remainder
        size_arr[dim], size_arr[dim + 1] = [count_sqrt_n, enc_tensor.size(dim + 1)]
        enc_max_tensor = split_enc_max.reshape(size_arr)
        full_max_tensor = crypten.cat([enc_max_tensor, remainder], dim=dim)

        # call the max function on dimension dim
        enc_max, enc_arg_max = full_max_tensor.max(dim=dim,
                                                   keepdim=True,
                                                   method="pairwise")
        # compute max over the resulting reduced tensor with n^2 algorithm
        # note that the resulting one-hot vector we get here finds maxes only
        # over the reduced vector in enc_tensor_reduced, so we won't use it
        return enc_max
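
A plaintext sketch of the same blocking idea (torch only, with illustrative names not taken from the source): recurse over blocks of size about sqrt(n), then take one final max over the roughly sqrt(n) block maxima plus the remainder, which is what gives the double-log round depth.

import math
import torch

def double_log_max(t: torch.Tensor) -> torch.Tensor:
    n = t.numel()
    if n == 1:
        return t
    sqrt_n = int(math.sqrt(n))
    count = n // sqrt_n
    blocks, remainder = t[: sqrt_n * count], t[sqrt_n * count:]
    # recursively reduce each block of size sqrt_n to its maximum ...
    block_max = torch.stack(
        [double_log_max(b) for b in blocks.reshape(count, sqrt_n)]
    ).flatten()
    # ... then one max over the block maxima and the leftover elements
    return torch.cat([block_max, remainder]).max().reshape(1)

print(double_log_max(torch.tensor([3.0, 9.0, 1.0, 7.0, 5.0])))  # tensor([9.])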
Code Example #13
File: mpc_train.py  Project: hitimr/ML_2020
def train_model_mpc():
    mem_before = get_process_memory()
    pid = comm.get().get_rank()
    ws = comm.get().world_size
    name = participants[pid]
    if pid == 0:
        print(f"Hello from the main process (rank#{pid} of {ws})!")
        print(f"My name is {name}.")
        print(f"My colleagues today are: ")
        print(participants)
    results = {
        "total": 0,
        "per_iter": [],
        "per_epoch": [],
        "inference": {
            "total": 0,
            "per_batch": [],
            "per_image": [],
            "average_per_image": 0
        },
        "mem_before": mem_before,
        "mem_after": None
    }
    LOG_STR = ""
    runtime = 0
    predictions = []
    targets = []
    class_correct = [0] * NUM_CLASSES
    class_total = [0] * NUM_CLASSES
    valid_loss_min = +np.inf

    # Setup log file per process
    postfix = f"{DATASET_NAME}_{ws}p_{pid}.log"
    memory_log = memory_dir / postfix
    runtimes_log = runtimes_dir / postfix
    results_log = results_dir / postfix

    # Load model
    dummy_image = torch.empty([1, NUM_CHANNELS, IMG_WIDTH,
                               IMG_HEIGHT])  # is that the right way around? :D
    #model = crypten.load(model_file_name, dummy_model=Net(), src=0)
    model_mpc = crypten.nn.from_pytorch(model, dummy_image)
    model_mpc.encrypt(src=0)

    if pid == 0:
        print("Gonna train now...")

    #model_mpc.eval()  # prep model for evaluation

    before_test.wait()

    for epoch in range(1, n_epochs + 1):
        # monitor losses
        train_loss = 0
        valid_loss = 0
        start = time()

        ###################
        # train the model #
        ###################
        iters = 0
        number_of_batches = len(train_loader)
        idx_to_show = np.arange(1, number_of_batches + 1,
                                int(number_of_batches / 100))
        for batch_idx, (data, target) in enumerate(train_loader):
            if pid == 0 and batch_idx in idx_to_show:
                print(
                    f"Batch: {(batch_idx+1) / (number_of_batches)*100:.2f}% --- {batch_idx+1}/{number_of_batches}"
                )
            start_iter = time()
            data_enc = []
            label_eye = torch.eye(10)
            target = label_eye[target]
            if ws > 2:
                for idx, batch in enumerate(
                        split_data_even(data, ws - 1, data.shape[0])):
                    data_enc.append(crypten.cryptensor(batch, src=idx + 1))
                #data_enc = crypten.cat(data_enc, dim=0)
            else:
                data_enc.append(crypten.cryptensor(data, src=1))

            for tensor in data_enc:
                tensor.set_grad_enabled = True

            target_enc = crypten.cryptensor(target)
            #target_enc.set_grad_enabled = True

            model_mpc.train()  # prep model for training

            # forward pass: compute predicted outputs by passing inputs to the model
            output = []
            start_batch_inference = time()
            # In each batch, each participant except the model holder has an equal share of the batch
            # Iterate over each participant's share
            for dat in data_enc:
                output.append(model_mpc(dat))
            stop_batch_inference = time()
            output = crypten.cat(output, dim=0)
            #output.set_grad_enabled = True
            # convert output probabilities to predicted class
            # pred = output.argmax(dim=1, one_hot=False)

            # calculate the loss
            if pid == 0:
                if output.shape != target_enc.shape:
                    print((output.shape, target_enc.shape))
            # loss = criterion(output, label) # pt
            loss = criterion(output, target_enc)  #.get_plain_text()

            # clear the gradients of all optimized variables
            model_mpc.zero_grad()
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            #optimizer.step()
            model_mpc.update_parameters(learning_rate)
            # update running training loss
            train_loss += loss.get_plain_text().item() * data.size(0)

            # ### compare predictions to true label
            # # decrypt predictions
            # pred = pred.get_plain_text()
            # correct = np.squeeze(pred.eq(target.data.view_as(pred)))
            # # calculate test accuracy for each object class
            # predictions.append(pred)
            # targets.append(target)

            results["per_iter"].append(time() - start_iter)
            results["inference"]["per_batch"].append(stop_batch_inference -
                                                     start_batch_inference)

            iters += 1
            iter_sync.wait()

        ###################
        # Save runtimes   #
        ###################
        stop = time()
        runtime = stop - start
        results["per_epoch"].append(runtime)
        results["total"] += runtime
        results["average_per_iter"] = np.mean(results["per_iter"])
        results["inference"]["total"] = np.sum(
            results["inference"]["per_batch"])
        results["inference"]["per_image"] = [
            x / batch_size for x in results["inference"]["per_batch"]
        ]
        results["inference"]["average_per_image"] = np.mean(
            results["inference"]["per_image"])
        # results = {
        #     "total": 0,
        #     "per_iter": [],
        #     "inference": {
        #         "total": 0,
        #         "per_batch": [],
        #         "per_image": [],
        #         "average_per_image": 0
        #     }
        # }

        ######################
        # validate the model #
        ######################
        model_mpc.eval()  # prep model for evaluation
        for data, label in valid_loader:
            data_enc = []
            if ws > 2:
                for idx, batch in enumerate(
                        split_data_even(data, ws - 1, data.shape[0])):
                    data_enc.append(crypten.cryptensor(batch, src=idx + 1))
                #data_enc = crypten.cat(data_enc, dim=0)
            else:
                data_enc.append(crypten.cryptensor(data, src=1))

            label_eye = torch.eye(10)
            label = label_eye[label]
            label_enc = crypten.cryptensor(label, src=0)
            # forward pass: compute predicted outputs by passing inputs to the model
            output = [model_mpc(dat) for dat in data_enc]
            output = crypten.cat(output, dim=0)
            if pid == 0:
                if output.shape != label_enc.shape:
                    print((output.shape, label_enc.shape))
            # calculate the loss
            loss = criterion(output, label_enc).get_plain_text()
            # update running validation loss
            valid_loss += loss.item() * data.size(0)

        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = train_loss / len(train_loader.sampler)
        valid_loss = valid_loss / len(valid_loader.sampler)

        tmp_str = f"Epoch: {epoch} \tTraining Loss: {train_loss:.6f} \tValidation Loss: {valid_loss:.6f}\n"
        LOG_STR += tmp_str
        if pid == 0:
            print(tmp_str)

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            model_dec = model_mpc.decrypt()

            tmp_str = f"Validation loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f}).  Saving model ...\n"
            LOG_STR += tmp_str
            if pid == 0:
                print(tmp_str)
                print(f"Saving model at {model_file_name}")
                #torch.save(model_mpc.state_dict(), model_file_name)
                torch.save(model_dec, model_file_name)
            valid_loss_min = valid_loss
            model_mpc.encrypt(src=0)
        log_memory(memory_log)

    if pid == 0:
        print("Done training...")

    after_test.wait()

    if pid == 0:
        print("Ouputing information...")

    # calculate and print avg test loss
    #test_loss = test_loss / len(test_loader.sampler)
    # if pid == 0:
    #     print(f"Test runtime: {runtime:5.2f}s\n\n")
    #     print(f"Test Loss: {test_loss:.6}\n")
    #     # Print accuracy per class
    #     for i in range(NUM_CLASSES):
    #         if class_total[i] > 0:
    #             print(
    #                 f"Test Accuracy of {i:5}: "
    #                 f"{100 * class_correct[i] / class_total[i]:3.0f}% "
    #                 f"({np.sum(class_correct[i]):4} / {np.sum(class_total[i]):4} )"
    #             )
    #         else:
    #             print(
    #                 f"Test Accuracy of {classes[i]}: N/A (no training examples)"
    #             )
    #     # Print overall accuracy
    #     print(
    #         f"\nTest Accuracy (Overall): {100. * np.sum(class_correct) / np.sum(class_total):3.0f}% "
    #         f"( {np.sum(class_correct)} / {np.sum(class_total)} )")

    # Gather log
    # LOG_STR = f"Rank: {pid}\nWorld_Size: {ws}\n\n"
    # LOG_STR += f"Test runtime: {runtime:5.2f}s\n"
    # LOG_STR += f"Test Loss: {test_loss:.6}\n"
    # LOG_STR += "\n"
    # for i in range(NUM_CLASSES):
    #     if class_total[i] > 0:
    #         LOG_STR += f"Test Accuracy of {i:5}: " \
    #               f"{100 * class_correct[i] / class_total[i]:3.0f}% " \
    #               f"({np.sum(class_correct[i]):4} / {np.sum(class_total[i]):4} )"
    #         LOG_STR += "\n"
    #     else:
    #         LOG_STR += f"Test Accuracy of {classes[i]}: N/A (no training examples)"
    #         LOG_STR += "\n"
    # LOG_STR += f"\nTest Accuracy (Overall): {100. * np.sum(class_correct) / np.sum(class_total):3.0f}% " + \
    #       f"( {np.sum(class_correct)} / {np.sum(class_total)} )"

    if pid == 0:
        print(LOG_STR)

    with open(f"./log/train/stdout_{pid}", "w") as f:
        f.write(LOG_STR)

    done.wait()
    mem_after = get_process_memory()
    results["mem_after"] = mem_after
    with open(results_log, 'w') as f:
        f.write(str(results))
    if pid == 0:
        with open(results_dir / f'latest_{pid}.txt', 'w') as f:
            f.write(str(results))

    return results
Code Example #14
def online_learner(
    sampler,
    backend="mpc",
    nr_iters=7,
    score_func=None,
    monitor_func=None,
    checkpoint_func=None,
    checkpoint_every=0,
):
    """
    Online learner that minimizes the linear least-squares loss.

    Args:
        sampler: An iterator that returns one sample at a time. Samples are
            assumed to be `dict`s with a `'context'` and a `'rewards'` field.
        backend: Which privacy protocol to use (default 'mpc').
        score_func: A closure that can be used to plug in exploration mechanisms.
        monitor_func: A closure that does logging.
        checkpoint_func: A closure that does checkpointing.
        nr_iters: Number of Newton-Raphson iterations to use for the private
            reciprocal.
    """

    # initialize some variables:
    total_reward = 0.0

    # initialize constructor for tensors:
    crypten.set_default_backend(backend)

    # loop over dataset:
    idx = 0
    for sample in sampler():
        start_t = time.time()

        # unpack sample:
        assert "context" in sample and "rewards" in sample, (
            "invalid sample: %s" % sample)

        context = crypten.cryptensor(sample["context"])
        num_features = context.nelement()
        num_arms = sample["rewards"].nelement()

        # initialization of model parameters:
        if idx == 0:

            # initialize accumulators for linear least squares:
            A_inv = [
                torch.eye(num_features).unsqueeze(0) for _ in range(num_arms)
            ]
            A_inv = crypten.cat([crypten.cryptensor(A) for A in A_inv])
            b = crypten.cryptensor(torch.zeros(num_arms, num_features))

            # compute initial weights for all arms:
            weights = b.unsqueeze(1).matmul(A_inv).squeeze(1)

        # compute score of all arms:
        scores = weights.matmul(context)

        # plug in exploration mechanism:
        if score_func is not None:
            score_func(scores, A_inv, b, context)

        onehot = scores.argmax()

        # In practice only one party opens the onehot vector in order to
        # take the action.
        selected_arm = onehot.get_plain_text().argmax()

        # Once the action is taken, the reward (a scalar) is observed by some
        # party and secret shared. Here we simulate that by selecting the
        # reward from the rewards vector and then sharing it.
        reward = crypten.cryptensor((sample["rewards"][selected_arm] >
                                     random.random()).view(1).float())

        # update linear least squares accumulators (using Sherman–Morrison
        # formula):
        A_inv_context = A_inv.matmul(context)
        numerator = A_inv_context.unsqueeze(1).mul(A_inv_context.unsqueeze(2))
        denominator = A_inv_context.matmul(context).add(1.0).view(-1, 1, 1)
        with crypten.mpc.ConfigManager("reciprocal_nr_iters", nr_iters):
            update = numerator.mul_(denominator.reciprocal())
        A_inv.sub_(update.mul_(onehot.view(-1, 1, 1)))
        b.add_(context.mul(reward).unsqueeze(0).mul_(onehot.unsqueeze(0)))

        # update model weights:
        weights = b.unsqueeze(1).matmul(A_inv).squeeze(1)

        # monitor learning progress: we use the plain reward only for
        # monitoring
        reward = reward.get_plain_text().item()
        total_reward += reward
        iter_time = time.time() - start_t
        if monitor_func is not None:
            monitor_func(idx, reward, total_reward, iter_time)
        idx += 1

        # checkpointing:
        if checkpoint_func is not None and idx % checkpoint_every == 0:
            checkpoint_func(
                idx,
                {
                    "A_inv": [AA.get_plain_text() for AA in A_inv],
                    "b": [bb.get_plain_text() for bb in b],
                },
            )

    # signal monitoring closure that we are done:
    if monitor_func is not None:
        monitor_func(idx, None, None, None, finished=True)
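
A plaintext check of the Sherman–Morrison rank-one update used above (a small torch-only sketch with illustrative names, not from the source): (A + c c^T)^{-1} = A^{-1} - (A^{-1} c)(A^{-1} c)^T / (1 + c^T A^{-1} c).

import torch

torch.manual_seed(0)
d = 4
A_inv = torch.eye(d)                       # inverse of A = I
c = torch.randn(d)                         # context vector
A_inv_c = A_inv @ c
update = torch.outer(A_inv_c, A_inv_c) / (1.0 + c @ A_inv_c)
new_inv = A_inv - update                   # Sherman–Morrison update
direct = torch.inverse(torch.eye(d) + torch.outer(c, c))
print(torch.allclose(new_inv, direct, atol=1e-5))  # expected: True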
Code Example #15
 def forward(ctx, input, dim=0):
     ctx.save_multiple_for_backward((dim, [t.size(dim) for t in input]))
     return crypten.cat(input, dim=dim)
Code Example #16
 def forward(self, input):
     assert isinstance(input,
                       (list, tuple)), "input needs to be a list or tuple"
     assert len(input) >= 1, "need at least one tensor to concatenate"
     return crypten.cat(input, self.dimension)
Code Example #17
File: util.py  Project: vishalbelsare/CrypTen
 def repeat_row(tensor, dim, ind):
     device = tensor.device
     x = tensor.index_select(dim, torch.arange(ind, device=device))
     y = tensor.index_select(dim, torch.arange(ind, tensor.size(dim), device=device))
     repeated_row = tensor.index_select(dim, torch.tensor(ind - 1, device=device))
     return crypten.cat([x, repeated_row, y], dim=dim)
Code Example #18
 def repeat_row(tensor, dim, ind):
     x = tensor.index_select(dim, torch.arange(ind))
     y = tensor.index_select(dim, torch.arange(ind, tensor.size(dim)))
     repeated_row = tensor.index_select(dim, torch.tensor(ind - 1))
     return crypten.cat([x, repeated_row, y], dim=dim)
Code Example #19
def test_model_mpc():
    mem_before = get_process_memory()   
    runtime = 0
    pid = comm.get().get_rank()
    ws = comm.get().world_size
    name = participants[pid]
    if pid == 0:
        print(f"Hello from the main process (rank#{pid} of {ws})!")
        print(f"My name is {name}.")
        print(f"My colleagues today are: ")
        print(participants)
    results = {
        "total": 0,
        "per_iter": [],
        "inference": {
            "total": 0,
            "per_batch": [],
            "per_image": [],
            "average_per_image": 0
        },
        "mem_before": mem_before,
        "mem_after": None
    }
    predictions = []
    targets = []
    class_correct = [0] * NUM_CLASSES
    class_total = [0] * NUM_CLASSES

    # Setup log files per process
    postfix = f"{DATASET_NAME}_{ws}p_{pid}.log"
    memory_log = memory_dir / postfix
    runtimes_log = runtimes_dir / postfix
    results_log = results_dir / postfix

    #convert_legacy_config() # LEGACY
    #model_mpc = crypten.nn.from_pytorch(model, dummy_image)
    # Instantiate and load the model
    model = Net()
    # Load model
    dummy_image = torch.empty([1, NUM_CHANNELS, IMG_WIDTH,
                               IMG_HEIGHT])  # is that the right way around? :D
                               
    #model = crypten.load(model_file_name, dummy_model=model)

    model.load_state_dict(torch.load(model_file_name))
    #model = crypten.load(model_file_name, dummy_model=model, src=0)
    model_mpc = crypten.nn.from_pytorch(model, dummy_image)
    model_mpc.encrypt(src=0)

    if pid == 0:
        print("Gonna evaluate now...")

    test_loss = 0.0
    model_mpc.eval()  # prep model for evaluation

    before_test.wait()
    start = time()
    
    iters = 0
    for data, target in tqdm(test_loader, position=0):  #, desc=f"{name}"):
        start_iter = time()
        data_enc = []
        if ws > 2:
            for idx, batch in enumerate(
                    split_data_even(data, ws - 1, data.shape[0])):
                data_enc.append(crypten.cryptensor(batch, src=idx + 1))
            #data_enc = crypten.cat(data_enc, dim=0)
        else:
            data_enc.append(crypten.cryptensor(data, src=1))

        target_enc = crypten.cryptensor(target, src=0)

        # forward pass: compute predicted outputs by passing inputs to the model
        output = []
        start_batch_inference = time()
        # In each batch, each participant except the model holder has an equal share of the batch
        # Iterate over each participant's share
        for dat in data_enc:
            output.append(model_mpc(dat))
        stop_batch_inference = time()

        output = crypten.cat(output, dim=0)
        # convert output probabilities to predicted class
        pred = output.argmax(dim=1, one_hot=False)
        # calculate the loss
        if pid == 0:
            if pred.shape != target_enc.shape:
                print((pred.shape, target_enc.shape))
        loss = criterion(pred, target_enc).get_plain_text()
        # update test loss
        test_loss += loss.item() * data.size(0)

        ### compare predictions to true label
        # decrypt predictions
        pred = pred.get_plain_text()
        correct = np.squeeze(pred.eq(target.data.view_as(pred)))
        # calculate test accuracy for each object class
        predictions.append(pred)
        targets.append(target)
        for i in range(len(target)):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1
        results["per_iter"].append(time() - start_iter)
        results["inference"]["per_batch"].append(stop_batch_inference - start_batch_inference)

        iters += 1
        iter_sync.wait()
        log_memory(memory_log)

    stop = time()
    runtime = stop - start
    results["total"] = runtime
    results["average_per_iter"] = np.mean(results["per_iter"])
    results["inference"]["total"] = np.sum(results["inference"]["per_batch"])
    results["inference"]["per_image"] = [x/batch_size for x in results["inference"]["per_batch"]]
    results["inference"]["average_per_image"] = np.mean(results["inference"]["per_image"])
    # results = {
    #     "total": 0,
    #     "per_iter": [],
    #     "inference": {
    #         "total": 0,
    #         "per_batch": [],
    #         "per_image": [],
    #         "average_per_image": 0
    #     }
    # }

    if pid == 0:
        print("Done evaluating...")

    after_test.wait()

    if pid == 0:
        print("Ouputing information...")

    # calculate and print avg test loss
    test_loss = test_loss / len(test_loader.sampler)
    # if pid == 0:
    #     print(f"Test runtime: {runtime:5.2f}s\n\n")
    #     print(f"Test Loss: {test_loss:.6}\n")
    #     # Print accuracy per class
    #     for i in range(NUM_CLASSES):
    #         if class_total[i] > 0:
    #             print(
    #                 f"Test Accuracy of {i:5}: "
    #                 f"{100 * class_correct[i] / class_total[i]:3.0f}% "
    #                 f"({np.sum(class_correct[i]):4} / {np.sum(class_total[i]):4} )"
    #             )
    #         else:
    #             print(
    #                 f"Test Accuracy of {classes[i]}: N/A (no training examples)"
    #             )
    #     # Print overall accuracy
    #     print(
    #         f"\nTest Accuracy (Overall): {100. * np.sum(class_correct) / np.sum(class_total):3.0f}% "
    #         f"( {np.sum(class_correct)} / {np.sum(class_total)} )")

    # Gather log
    LOG_STR = f"Rank: {pid}\nWorld_Size: {ws}\n\n"
    LOG_STR += f"Test runtime: {runtime:5.2f}s\n"
    LOG_STR += f"Test Loss: {test_loss:.6}\n"
    LOG_STR += "\n"
    for i in range(NUM_CLASSES):
        if class_total[i] > 0:
            LOG_STR += f"Test Accuracy of {i:5}: " \
                  f"{100 * class_correct[i] / class_total[i]:3.0f}% " \
                  f"({np.sum(class_correct[i]):4} / {np.sum(class_total[i]):4} )"
            LOG_STR += "\n"
        else:
            LOG_STR += f"Test Accuracy of {classes[i]}: N/A (no training examples)"
            LOG_STR += "\n"
    LOG_STR += f"\nTest Accuracy (Overall): {100. * np.sum(class_correct) / np.sum(class_total):3.0f}% " + \
          f"( {np.sum(class_correct)} / {np.sum(class_total)} )"
    
    if pid == 0:
        print(LOG_STR)

    with open(log_dir / f"stdout_{pid}", "w") as f:
        f.write(LOG_STR)
    
    done.wait()
    mem_after = get_process_memory()
    results["mem_after"] = mem_after
    with open(results_log, 'w') as f:
        f.write(str(results))
    if pid == 0:
        with open(results_dir / f'latest_{pid}.txt', 'w') as f:
            f.write(str(results))

    return results