def validate_mpc(dataloader: DataLoader, model: crypten.nn.Module, loss: crypten.nn.Module):
    """Evaluate an encrypted binary classifier over ``dataloader``.

    The model is put in eval mode and run on every ``(xs, ys)`` batch. The
    detached encrypted batch losses are summed. The concatenated outputs and
    labels are then decrypted to compute plaintext metrics with a sigmoid and
    a 0.5 threshold.

    Args:
        dataloader: Iterable of ``(xs, ys)`` batches of encrypted tensors;
            must support ``len()``.
        model: Encrypted CrypTen model. Presumably it emits one logit per
            sample (the sigmoid/threshold below assumes binary output).
        loss: CrypTen loss module, called as ``loss(out, ys)``.

    Returns:
        Tuple ``(mean_loss, precision, recall, roc_auc)``. ``mean_loss`` is
        the summed batch loss divided by ``len(dataloader)``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    outs = []
    label_batches = []
    total_loss = None
    count = len(dataloader)
    for xs, ys in tqdm(dataloader, file=sys.stdout):
        out = model(xs)
        loss_val = loss(out, ys)
        outs.append(out)
        label_batches.append(ys)
        if total_loss is None:
            total_loss = loss_val.detach()
        else:
            total_loss += loss_val.detach()

    # Without this guard an empty loader fails later with an opaque
    # AttributeError on ``None.get_plain_text()``.
    if total_loss is None:
        raise ValueError("validate_mpc received a dataloader with no batches")

    total_loss = total_loss.get_plain_text().item()

    all_prob = crypten.cat(outs, dim=0).sigmoid().get_plain_text()
    pred_ys = torch.where(all_prob > 0.5, 1, 0).tolist()
    pred_probs = all_prob.tolist()

    true_ys = crypten.cat(label_batches, dim=0).get_plain_text().tolist()

    return (
        total_loss / count,
        precision_score(true_ys, pred_ys),
        recall_score(true_ys, pred_ys),
        roc_auc_score(true_ys, pred_probs),
    )
def randn(*sizes, device=None):
    """Sample an encrypted tensor of shape ``sizes`` from a standard normal.

    Implements the Box-Muller transform. Uniform samples are consumed in
    pairs ``(u1, u2)``, giving radius ``sqrt(-2 * ln(u1))`` and angle
    ``2*pi*(u2 - 0.5)``. Each pair yields two independent normals,
    ``radius * sin(angle)`` and ``radius * cos(angle)``. When the element
    count is odd, one extra uniform sample is drawn and one output is dropped.
    """
    uniform = crypten.rand(*sizes, device=device).flatten()
    padded = uniform.numel() % 2 == 1
    if padded:
        # Box-Muller needs an even number of uniforms.
        uniform = crypten.cat([uniform, crypten.rand((1, ), device=device)])

    half = uniform.numel() // 2
    first_half, second_half = uniform[:half], uniform[half:]

    radius = (-2 * first_half.log(input_in_01=True)).sqrt()
    # The angle is centred on zero (u2 - 0.5); cos/sin are periodic, so this is
    # still a uniform angle over a full turn.
    cos_part, sin_part = second_half.sub(0.5).mul(6.28318531).cossin()

    samples = crypten.cat([radius.mul(sin_part), radius.mul(cos_part)])
    if padded:
        samples = samples[1:]
    return samples.view(*sizes)
def test_case9(input, encr_input):
    """Build the same cat/mean/cat/sum graph in PyTorch and in CrypTen.

    Returns:
        ``(output, encr_output)``: the plaintext result and its encrypted
        counterpart.
    """
    stacked = torch.cat([input, input])
    column_mean = stacked.mean(0, keepdim=True)
    output = torch.cat([column_mean, stacked], dim=0).sum()

    encr_stacked = crypten.cat([encr_input, encr_input])
    encr_column_mean = encr_stacked.mean(0, keepdim=True)
    encr_output = crypten.cat([encr_column_mean, encr_stacked]).sum()

    return output, encr_output
def test_case8(input, encr_input):
    """Build the same add/cat/pow/cat/sum graph in PyTorch and in CrypTen.

    Returns:
        ``(output, encr_output)``: the plaintext result and its encrypted
        counterpart.
    """
    shifted = input.add(3.0)
    joined = torch.cat([input, shifted])
    squared = joined.pow(2.0)
    output = torch.cat([input, joined, squared]).add(-1).sum()

    encr_shifted = encr_input.add(3.0)
    encr_joined = crypten.cat([encr_input, encr_shifted])
    encr_squared = encr_joined.pow(2.0)
    encr_output = crypten.cat([encr_input, encr_joined, encr_squared]).add(-1).sum()

    return output, encr_output
def extend_row(tensor, dim, start_ind, end_ind):
    """Insert one pooled slice into ``tensor`` at index ``start_ind`` along ``dim``.

    The inserted slice depends on ``reduction``:

    - ``"mean"``: the mean of the slices in ``[start_ind, end_ind)``.
    - ``"max"``: a copy of the slice at ``start_ind``.

    ``reduction`` is not a parameter. It is read from the enclosing scope
    (presumably the adaptive-pooling helper this function is nested in;
    verify).

    Index tensors are created on ``tensor.device`` when the tensor has one, so
    ``index_select`` never mixes devices. This matches the device-aware
    ``repeat_row``.

    Raises:
        ValueError: if ``reduction`` is neither ``"mean"`` nor ``"max"``.
    """
    device = getattr(tensor, "device", None)
    if reduction == "mean":
        extended_value = tensor.index_select(
            dim, torch.arange(start_ind, end_ind, device=device))
        extended_value = extended_value.mean(dim, keepdim=True)
    elif reduction == "max":
        extended_value = tensor.index_select(
            dim, torch.tensor(start_ind, device=device))
    else:
        raise ValueError(f"Invalid reduction {reduction} for adaptive pooling.")

    if start_ind == 0:
        return crypten.cat([extended_value, tensor], dim=dim)

    before = tensor.index_select(dim, torch.arange(start_ind, device=device))
    after = tensor.index_select(
        dim, torch.arange(start_ind, tensor.size(dim), device=device))
    return crypten.cat([before, extended_value, after], dim=dim)
def polynomial(self, coeffs, func="mul"):
    """Evaluate a polynomial without a constant term: sum_k coeffs[k-1] * self**k.

    ``coeffs`` may be a list, a 1-D torch tensor, or a 1-D encrypted tensor.
    The linear coefficient comes first and the highest-order coefficient
    last.

    Powers are built by repeated doubling: each round combines every power
    computed so far with the current highest one, using ``func`` (``"mul"``
    by default).
    """
    if isinstance(coeffs, list):
        coeffs = torch.tensor(coeffs)
    assert torch.is_tensor(coeffs) or crypten.is_encrypted_tensor(
        coeffs), "Polynomial coefficients must be a list or tensor"
    assert coeffs.dim() == 1, "Polynomial coefficients must be a 1-D tensor"

    degree = coeffs.size(0)
    if degree == 1:
        return self.mul(coeffs)

    # powers[k] holds self ** (k + 1) when func == "mul".
    powers = crypten.mpc.stack([self, self.square()])
    while powers.size(0) < degree:
        top_power = powers[-1:].expand(powers.size())
        powers = crypten.cat([powers, getattr(powers, func)(top_power)])
    powers = powers[:degree]

    # Append singleton axes to coeffs so it broadcasts against each power.
    extra_dims = powers.dim() - 1
    for _ in range(extra_dims):
        coeffs = coeffs.unsqueeze(1)

    return powers.mul(coeffs).sum(0)
def _compute_pairwise_comparisons_for_steps(input_tensor, dim, steps):
    """Shrink ``input_tensor`` along ``dim`` by pairwise max, ``steps`` times.

    Each step splits the current tensor along ``dim`` into three parts:

    - a first half,
    - an equally sized second half,
    - a remainder of size ``size % 2``.

    The two halves are compared element-wise via
    ``crypten.where(a >= b, a, b)``, and the remainder is appended unchanged.
    """
    reduced = input_tensor.clone()
    for _ in range(steps):
        size = reduced.size(dim)
        half = size // 2
        first, second, leftover = reduced.split([half, half, size % 2], dim=dim)
        winners = crypten.where(first >= second, first, second)
        reduced = crypten.cat([winners, leftover], dim=dim)
    return reduced
def test_case5(input, encr_input):
    """Build the same mul/add/pow/cat graph in PyTorch and in CrypTen.

    On the CrypTen side, ``(x + 2) ** 2`` is computed with ``square()``. The
    PyTorch side uses ``pow(2.0)``.

    Returns:
        ``(output, encr_output)``: the plaintext result and its encrypted
        counterpart.
    """
    plain_parts = [
        input.mul(3.0),
        input.add(2.0).pow(2.0),
        input.pow(2.0),
    ]
    output = torch.cat(plain_parts).mul(0.5).sum()

    encr_parts = [
        encr_input.mul(3.0),
        encr_input.add(2.0).square(),
        encr_input.pow(2.0),
    ]
    encr_output = crypten.cat(encr_parts).mul(0.5).sum()

    return output, encr_output
def run_mpc_autograd_cnn(
    context_manager=None,
    num_epochs=3,
    learning_rate=0.001,
    batch_size=5,
    print_freq=5,
    num_samples=100,
):
    """Train an encrypted CNN on MNIST features held by two parties.

    Rank 0 owns ``data_alice`` and rank 1 owns ``data_bob``. Every other rank
    contributes an empty placeholder of the same shape when encrypting.
    Assumes at least two parties.

    Args:
        context_manager: used for setting proxy settings during download.
        num_epochs: forwarded to ``train_encrypted``.
        learning_rate: forwarded to ``train_encrypted``.
        batch_size: forwarded to ``train_encrypted``.
        print_freq: forwarded to ``train_encrypted``.
        num_samples: number of leading training samples (and labels) kept.
    """
    crypten.init()

    data_alice, data_bob, train_labels = preprocess_mnist(context_manager)
    rank = comm.get().get_rank()

    # Only the owner holds real values; other ranks pass same-shaped dummies.
    x_alice = data_alice if rank == 0 else torch.empty(data_alice.size())
    x_bob = data_bob if rank == 1 else torch.empty(data_bob.size())

    x_alice_enc = crypten.cryptensor(x_alice, src=0)
    x_bob_enc = crypten.cryptensor(x_bob, src=1)

    # Join the two parties' halves along dim 2, then add a channel axis at
    # dim 1 to match the (1, 1, 28, 28) dummy input used below.
    features_enc = crypten.cat([x_alice_enc, x_bob_enc], dim=2).unsqueeze(1)

    x_reduced = features_enc[:num_samples]
    y_reduced = train_labels[:num_samples]

    model = crypten.nn.from_pytorch(CNN(), torch.empty((1, 1, 28, 28)))
    model.train()
    model.encrypt()

    train_encrypted(
        x_reduced, y_reduced, model, num_epochs, learning_rate, batch_size, print_freq
    )
def test_case6(input, encr_input):
    """Build the same gather/pow/add/cat graph in PyTorch and in CrypTen.

    The row indices go up to 8 and the first index tensor has width 5, so
    ``input`` needs at least 9 rows and 5 columns.

    Returns:
        ``(output, encr_output)``: the plaintext result and its encrypted
        counterpart.
    """
    rows_a = torch.tensor([[0, 2, 4, 3, 8]], dtype=torch.long)
    rows_b = torch.tensor([[5, 1, 3, 5, 2]], dtype=torch.long)
    cols = torch.tensor([[2, 3, 1]], dtype=torch.long)

    plain_squared = input.gather(0, rows_a).gather(1, cols).pow(2.0)
    plain_shifted = input.gather(0, rows_b).gather(1, cols).add(-2.0)
    output = torch.cat([plain_squared, plain_shifted]).mul(0.5).sum()

    encr_squared = encr_input.gather(0, rows_a).gather(1, cols).square()
    encr_shifted = encr_input.gather(0, rows_b).gather(1, cols).add(-2.0)
    encr_output = crypten.cat([encr_squared, encr_shifted], dim=0).mul(0.5).sum()

    return output, encr_output
def load_encrypt_tensor(filename: str) -> crypten.CrypTensor:
    """Secret-share each party's local feature block and join them column-wise.

    Party ``i`` (in the order of ``names`` / ``feature_sizes``) provides its
    locally loaded tensor. Every other rank provides a float32 zero
    placeholder with this rank's row count and party ``i``'s feature size.
    This assumes all parties hold the same number of rows.

    Returns:
        The encrypted tensors concatenated along dim 1.
    """
    local_tensor = load_local_tensor(filename)
    rank = comm.get().get_rank()
    num_rows = local_tensor.shape[0]

    shares = []
    for owner, (name, feature_size) in enumerate(zip(names, feature_sizes)):
        if rank != owner:
            source = torch.zeros((num_rows, feature_size), dtype=torch.float32)
        else:
            assert local_tensor.shape[1] == feature_size, \
                f"{name} feature size should be {feature_size}, but get {local_tensor.shape[1]}"
            source = local_tensor
        shares.append(crypten.cryptensor(source, src=owner))

    return crypten.cat(shares, dim=1)
def _max_helper_double_log_recursive(enc_tensor, dim):
    """Recursive subroutine for computing max via double log reduction algorithm

    Returns the maximum of ``enc_tensor`` along ``dim``, keeping that dimension
    with size 1.

    Requirements:
        - ``enc_tensor.size(dim)`` must be >= 1; for 0, ``n // sqrt_n`` divides
          by zero.
        - ``dim + 1`` must be a valid dimension, because its size is read
          below.
    """
    n = enc_tensor.size(dim)
    # compute integral sqrt(n) and the integer number of sqrt(n) size
    # vectors that can be extracted from n
    sqrt_n = int(math.sqrt(n))
    count_sqrt_n = n // sqrt_n
    # base case for recursion: no further splits along dimension dim
    if n == 1:
        return enc_tensor
    else:
        # split into tensors that can be broken into vectors of size sqrt(n)
        # and the remainder of the tensor
        size_arr = [sqrt_n * count_sqrt_n, n % sqrt_n]
        split_enc_tensor, remainder = enc_tensor.split(size_arr, dim=dim)

        # reshape such that dim holds sqrt_n and dim+1 holds count_sqrt_n
        # copies of the original dim+1 extent. Under row-major reshape, new
        # position (i, j) maps to original (i * count_sqrt_n + j // D, j % D),
        # where D = enc_tensor.size(dim + 1). A max over dim therefore reduces,
        # for each r < count_sqrt_n, across the rows i * count_sqrt_n + r.
        updated_enc_tensor_size = [
            sqrt_n, enc_tensor.size(dim + 1) * count_sqrt_n
        ]
        size_arr = [enc_tensor.size(i) for i in range(enc_tensor.dim())]
        size_arr[dim], size_arr[dim + 1] = updated_enc_tensor_size
        split_enc_tensor = split_enc_tensor.reshape(size_arr)

        # recursive call on reshaped tensor (size sqrt_n along dim); the result
        # has size 1 along dim
        split_enc_max = _max_helper_double_log_recursive(split_enc_tensor, dim)

        # reshape the result to have the (dim+1)th dimension as before
        # and concatenate the previously computed remainder; dim now holds
        # count_sqrt_n partial maxima followed by the n % sqrt_n leftover rows
        size_arr[dim], size_arr[dim + 1] = [count_sqrt_n, enc_tensor.size(dim + 1)]
        enc_max_tensor = split_enc_max.reshape(size_arr)
        full_max_tensor = crypten.cat([enc_max_tensor, remainder], dim=dim)

        # call the max function on dimension dim
        enc_max, enc_arg_max = full_max_tensor.max(dim=dim, keepdim=True, method="pairwise")
        # compute max over the resulting reduced tensor with n^2 algorithm
        # note that the resulting one-hot vector we get here finds maxes only
        # over the reduced vector in enc_tensor_reduced, so we won't use it
        return enc_max
def train_model_mpc():
    """Train the global ``model`` under CrypTen MPC, validating after every epoch.

    Relies on module-level state defined elsewhere in the file:
    ``participants``, ``model``, ``train_loader``, ``valid_loader``,
    ``criterion``, ``n_epochs``, ``learning_rate``, ``batch_size``,
    ``model_file_name``, the size constants, the log directories, and the
    barriers ``before_test`` / ``iter_sync`` / ``after_test`` / ``done``.

    Data flow:
        - Rank 0 supplies the model weights and the labels.
        - Input batches are supplied by ranks 1..ws-1 via ``split_data_even``
          when ``ws > 2``; otherwise rank 1 supplies them whole.
        - Whenever the validation loss does not increase, the decrypted model
          is saved to ``model_file_name``.

    Returns:
        dict of timing and memory measurements. The same dict is also written
        to this rank's results log.
    """
    mem_before = get_process_memory()
    pid = comm.get().get_rank()
    ws = comm.get().world_size
    name = participants[pid]
    if pid == 0:
        print(f"Hello from the main process (rank#{pid} of {ws})!")
        print(f"My name is {name}.")
        print(f"My colleagues today are: ")
        print(participants)

    results = {
        "total": 0,
        "per_iter": [],
        "per_epoch": [],
        "inference": {
            "total": 0,
            "per_batch": [],
            "per_image": [],
            "average_per_image": 0,
        },
        "mem_before": mem_before,
        "mem_after": None,
    }
    LOG_STR = ""
    valid_loss_min = +np.inf

    # Setup log file per process
    postfix = f"{DATASET_NAME}_{ws}p_{pid}.log"
    memory_log = memory_dir / postfix
    results_log = results_dir / postfix

    # One-hot lookup table for the 10 class labels.
    label_eye = torch.eye(10)

    def encrypt_shares(batch_data):
        """Encrypt ``batch_data`` as a list of CrypTensors, one per data-holding rank."""
        if ws > 2:
            chunks = split_data_even(batch_data, ws - 1, batch_data.shape[0])
            return [crypten.cryptensor(chunk, src=idx + 1) for idx, chunk in enumerate(chunks)]
        return [crypten.cryptensor(batch_data, src=1)]

    # Load model
    dummy_image = torch.empty([1, NUM_CHANNELS, IMG_WIDTH, IMG_HEIGHT])
    model_mpc = crypten.nn.from_pytorch(model, dummy_image)
    model_mpc.encrypt(src=0)
    if pid == 0:
        print("Gonna train now...")

    before_test.wait()
    for epoch in range(1, n_epochs + 1):
        train_loss = 0
        valid_loss = 0
        start = time()

        ###################
        # train the model #
        ###################
        number_of_batches = len(train_loader)
        # Report progress roughly every 1% of batches. The step must be at
        # least 1: a zero step (fewer than 100 batches) made np.arange raise.
        progress_step = max(1, number_of_batches // 100)
        idx_to_show = set(range(1, number_of_batches + 1, progress_step))

        model_mpc.train()  # prep model for training
        for batch_idx, (data, target) in enumerate(train_loader):
            if pid == 0 and batch_idx in idx_to_show:
                print(
                    f"Batch: {(batch_idx+1) / (number_of_batches)*100:.2f}% --- {batch_idx+1}/{number_of_batches}"
                )
            start_iter = time()
            target = label_eye[target]
            data_enc = encrypt_shares(data)
            target_enc = crypten.cryptensor(target)

            # Forward pass: each data-holding party's share of the batch in turn.
            start_batch_inference = time()
            output = [model_mpc(dat) for dat in data_enc]
            stop_batch_inference = time()
            output = crypten.cat(output, dim=0)

            if pid == 0 and output.shape != target_enc.shape:
                print((output.shape, target_enc.shape))
            loss = criterion(output, target_enc)

            # Zero the gradients of the *encrypted* model. The plaintext
            # ``model`` is not the one being trained; zeroing it let
            # model_mpc's gradients accumulate across iterations.
            model_mpc.zero_grad()
            loss.backward()
            model_mpc.update_parameters(learning_rate)

            train_loss += loss.get_plain_text().item() * data.size(0)

            results["per_iter"].append(time() - start_iter)
            results["inference"]["per_batch"].append(stop_batch_inference - start_batch_inference)
            iter_sync.wait()

        ###################
        # Save runtimes   #
        ###################
        stop = time()
        runtime = stop - start
        results["per_epoch"].append(runtime)
        results["total"] += runtime
        results["average_per_iter"] = np.mean(results["per_iter"])
        results["inference"]["total"] = np.sum(results["inference"]["per_batch"])
        results["inference"]["per_image"] = [
            x / batch_size for x in results["inference"]["per_batch"]
        ]
        results["inference"]["average_per_image"] = np.mean(results["inference"]["per_image"])

        ######################
        # validate the model #
        ######################
        model_mpc.eval()  # the encrypted model is the one that runs below
        for data, label in valid_loader:
            data_enc = encrypt_shares(data)
            label_enc = crypten.cryptensor(label_eye[label], src=0)

            output = crypten.cat([model_mpc(dat) for dat in data_enc], dim=0)
            if pid == 0 and output.shape != label_enc.shape:
                print((output.shape, label_enc.shape))
            loss = criterion(output, label_enc).get_plain_text()
            # Accumulate over batches. Plain assignment kept only the last
            # batch's loss.
            valid_loss += loss.item() * data.size(0)

        # calculate average loss over an epoch
        train_loss = train_loss / len(train_loader.sampler)
        valid_loss = valid_loss / len(valid_loader.sampler)

        tmp_str = f"Epoch: {epoch} \tTraining Loss: {train_loss:.6f} \tValidation Loss: {valid_loss:.6f}\n"
        LOG_STR += tmp_str
        if pid == 0:
            print(tmp_str)

        # save model if validation loss has decreased; decrypt/encrypt are
        # collective, so every rank takes this branch together
        if valid_loss <= valid_loss_min:
            model_dec = model_mpc.decrypt()
            tmp_str = f"Validation loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f}). Saving model ...\n"
            LOG_STR += tmp_str
            if pid == 0:
                print(tmp_str)
                print(f"Saving model at {model_file_name}")
                # NOTE(review): saved by rank 0 only; confirm this matches the
                # original intent (its indentation was lost).
                torch.save(model_dec, model_file_name)
            valid_loss_min = valid_loss
            model_mpc.encrypt(src=0)

    log_memory(memory_log)
    if pid == 0:
        print("Done training...")
    after_test.wait()
    if pid == 0:
        print("Ouputing information...")
        print(LOG_STR)
    with open(f"./log/train/stdout_{pid}", "w") as f:
        f.write(LOG_STR)

    done.wait()
    results["mem_after"] = get_process_memory()
    with open(results_log, 'w') as f:
        f.write(str(results))
    if pid == 0:
        with open(results_dir / f'latest_{pid}.txt', 'w') as f:
            f.write(str(results))
    return results
def online_learner(
    sampler,
    backend="mpc",
    nr_iters=7,
    score_func=None,
    monitor_func=None,
    checkpoint_func=None,
    checkpoint_every=0,
):
    """
    Online learner that minimizes linear least squared loss.

    Args:
        sampler: A callable that returns an iterator yielding one sample at a
            time. Samples are `dict`s with a `'context'` and a `'rewards'`
            field.
        backend: Which privacy protocol to use (default 'mpc').
        nr_iters: Number of Newton-Raphson iterations to use for private
            reciprocal.
        score_func: A closure that can be used to plug in exploration
            mechanisms.
        monitor_func: A closure that does logging.
        checkpoint_func: A closure that does checkpointing.
        checkpoint_every: Call `checkpoint_func` after every this many
            samples. A value <= 0 (the default) disables checkpointing rather
            than raising ZeroDivisionError.
    """

    # initialize some variables:
    total_reward = 0.0

    # initialize constructor for tensors:
    crypten.set_default_backend(backend)

    # loop over dataset:
    idx = 0
    for sample in sampler():
        start_t = time.time()

        # unpack sample:
        assert "context" in sample and "rewards" in sample, (
            "invalid sample: %s" % sample)

        context = crypten.cryptensor(sample["context"])
        num_features = context.nelement()
        num_arms = sample["rewards"].nelement()

        # initialization of model parameters:
        if idx == 0:

            # initialize accumulators for linear least squares:
            A_inv = [
                torch.eye(num_features).unsqueeze(0) for _ in range(num_arms)
            ]
            A_inv = crypten.cat([crypten.cryptensor(A) for A in A_inv])
            b = crypten.cryptensor(torch.zeros(num_arms, num_features))

            # compute initial weights for all arms:
            weights = b.unsqueeze(1).matmul(A_inv).squeeze(1)

        # compute score of all arms:
        scores = weights.matmul(context)

        # plug in exploration mechanism:
        if score_func is not None:
            score_func(scores, A_inv, b, context)

        onehot = scores.argmax()

        # In practice only one party opens the onehot vector in order to
        # take the action.
        selected_arm = onehot.get_plain_text().argmax()

        # Once the action is taken, the reward (a scalar) is observed by some
        # party and secret shared. Here we simulate that by selecting the
        # reward from the rewards vector and then sharing it.
        reward = crypten.cryptensor(
            (sample["rewards"][selected_arm] > random.random()).view(1).float())

        # update linear least squares accumulators (using Sherman–Morrison
        # formula):
        A_inv_context = A_inv.matmul(context)
        numerator = A_inv_context.unsqueeze(1).mul(A_inv_context.unsqueeze(2))
        denominator = A_inv_context.matmul(context).add(1.0).view(-1, 1, 1)
        with crypten.mpc.ConfigManager("reciprocal_nr_iters", nr_iters):
            update = numerator.mul_(denominator.reciprocal())
        # onehot masks the update so only the selected arm's accumulators change
        A_inv.sub_(update.mul_(onehot.view(-1, 1, 1)))
        b.add_(context.mul(reward).unsqueeze(0).mul_(onehot.unsqueeze(0)))

        # update model weights:
        weights = b.unsqueeze(1).matmul(A_inv).squeeze(1)

        # monitor learning progress: we use the plain reward only for
        # monitoring
        reward = reward.get_plain_text().item()
        total_reward += reward
        iter_time = time.time() - start_t
        if monitor_func is not None:
            monitor_func(idx, reward, total_reward, iter_time)
        idx += 1

        # checkpointing (guarded so the default checkpoint_every=0 does not
        # trigger a modulo by zero):
        if (checkpoint_func is not None and checkpoint_every > 0
                and idx % checkpoint_every == 0):
            checkpoint_func(
                idx,
                {
                    "A_inv": [AA.get_plain_text() for AA in A_inv],
                    "b": [bb.get_plain_text() for bb in b],
                },
            )

    # signal monitoring closure that we are done:
    if monitor_func is not None:
        monitor_func(idx, None, None, None, finished=True)
def forward(ctx, input, dim=0):
    """Concatenate the tensors in ``input`` along ``dim``.

    Saves ``(dim, split_sizes)`` on ``ctx``, where ``split_sizes[i]`` is
    ``input[i].size(dim)``, so the backward pass can split the gradient.
    """
    split_sizes = [part.size(dim) for part in input]
    ctx.save_multiple_for_backward((dim, split_sizes))
    return crypten.cat(input, dim=dim)
def forward(self, input):
    """Concatenate a non-empty list/tuple of tensors along ``self.dimension``."""
    assert isinstance(input, (list, tuple)), "input needs to be a list or tuple"
    assert len(input) > 0, "need at least one tensor to concatenate"
    axis = self.dimension
    return crypten.cat(input, axis)
def repeat_row(tensor, dim, ind):
    """Duplicate the slice at ``ind - 1`` along ``dim`` and insert the copy at ``ind``.

    All index tensors are created on ``tensor.device``.
    """
    dev = tensor.device
    size = tensor.size(dim)
    head = tensor.index_select(dim, torch.arange(ind, device=dev))
    tail = tensor.index_select(dim, torch.arange(ind, size, device=dev))
    duplicate = tensor.index_select(dim, torch.tensor(ind - 1, device=dev))
    return crypten.cat([head, duplicate, tail], dim=dim)
def repeat_row(tensor, dim, ind):
    """Duplicate the slice at ``ind - 1`` along ``dim`` and insert the copy at ``ind``.

    Index tensors are created on ``tensor.device`` when the tensor exposes
    one, so ``index_select`` does not mix a CPU index with a tensor on
    another device. Tensors without a ``device`` attribute fall back to
    PyTorch's default device, which is the original behaviour.
    """
    device = getattr(tensor, "device", None)
    x = tensor.index_select(dim, torch.arange(ind, device=device))
    y = tensor.index_select(dim, torch.arange(ind, tensor.size(dim), device=device))
    repeated_row = tensor.index_select(dim, torch.tensor(ind - 1, device=device))
    return crypten.cat([x, repeated_row, y], dim=dim)
def test_model_mpc():
    """Evaluate the saved model under CrypTen MPC on ``test_loader``.

    Relies on module-level state defined elsewhere in the file:
    ``participants``, ``Net``, ``model_file_name``, ``test_loader``,
    ``criterion``, ``batch_size``, ``classes``, ``NUM_CLASSES``, the size
    constants, the log directories, and the barriers ``before_test`` /
    ``iter_sync`` / ``after_test`` / ``done``.

    Data flow:
        - Rank 0 encrypts the model weights and the labels.
        - Input batches come from ranks 1..ws-1 via ``split_data_even`` when
          ``ws > 2``; otherwise rank 1 supplies them whole.

    Returns:
        dict of timing and memory measurements. The same dict is also written
        to this rank's results log.
    """
    mem_before = get_process_memory()
    pid = comm.get().get_rank()
    ws = comm.get().world_size
    name = participants[pid]
    if pid == 0:
        print(f"Hello from the main process (rank#{pid} of {ws})!")
        print(f"My name is {name}.")
        print(f"My colleagues today are: ")
        print(participants)

    results = {
        "total": 0,
        "per_iter": [],
        "inference": {
            "total": 0,
            "per_batch": [],
            "per_image": [],
            "average_per_image": 0,
        },
        "mem_before": mem_before,
        "mem_after": None,
    }
    class_correct = [0] * NUM_CLASSES
    class_total = [0] * NUM_CLASSES

    # Setup log files per process
    postfix = f"{DATASET_NAME}_{ws}p_{pid}.log"
    memory_log = memory_dir / postfix
    results_log = results_dir / postfix

    # Instantiate and load the model
    model = Net()
    dummy_image = torch.empty([1, NUM_CHANNELS, IMG_WIDTH, IMG_HEIGHT])
    model.load_state_dict(torch.load(model_file_name))
    model_mpc = crypten.nn.from_pytorch(model, dummy_image)
    model_mpc.encrypt(src=0)
    if pid == 0:
        print("Gonna evaluate now...")

    test_loss = 0.0
    model_mpc.eval()  # prep model for evaluation

    before_test.wait()
    start = time()
    for data, target in tqdm(test_loader, position=0):
        start_iter = time()

        # Each data-holding party encrypts its share of the batch.
        if ws > 2:
            chunks = split_data_even(data, ws - 1, data.shape[0])
            data_enc = [crypten.cryptensor(chunk, src=idx + 1) for idx, chunk in enumerate(chunks)]
        else:
            data_enc = [crypten.cryptensor(data, src=1)]
        target_enc = crypten.cryptensor(target, src=0)

        start_batch_inference = time()
        output = [model_mpc(dat) for dat in data_enc]
        stop_batch_inference = time()
        output = crypten.cat(output, dim=0)

        # convert output probabilities to predicted class
        pred = output.argmax(dim=1, one_hot=False)

        if pid == 0 and pred.shape != target_enc.shape:
            print((pred.shape, target_enc.shape))
        # NOTE(review): the loss is computed on argmax class indices, not on
        # the raw outputs; confirm this is what ``criterion`` expects.
        loss = criterion(pred, target_enc).get_plain_text()
        test_loss += loss.item() * data.size(0)

        # compare decrypted predictions to the true labels. reshape(-1)
        # instead of np.squeeze keeps a 1-D result for a batch of size 1,
        # where squeeze produced a 0-d tensor and correct[i] raised IndexError.
        pred = pred.get_plain_text()
        correct = pred.eq(target.data.view_as(pred)).reshape(-1)
        for i in range(len(target)):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

        results["per_iter"].append(time() - start_iter)
        results["inference"]["per_batch"].append(stop_batch_inference - start_batch_inference)
        iter_sync.wait()

    log_memory(memory_log)
    stop = time()
    runtime = stop - start
    results["total"] = runtime
    results["average_per_iter"] = np.mean(results["per_iter"])
    results["inference"]["total"] = np.sum(results["inference"]["per_batch"])
    results["inference"]["per_image"] = [x / batch_size for x in results["inference"]["per_batch"]]
    results["inference"]["average_per_image"] = np.mean(results["inference"]["per_image"])

    if pid == 0:
        print("Done evaluating...")
    after_test.wait()
    if pid == 0:
        print("Ouputing information...")

    # average test loss over all test samples
    test_loss = test_loss / len(test_loader.sampler)

    # Gather log
    LOG_STR = f"Rank: {pid}\nWorld_Size: {ws}\n\n"
    LOG_STR += f"Test runtime: {runtime:5.2f}s\n"
    LOG_STR += f"Test Loss: {test_loss:.6}\n"
    LOG_STR += "\n"
    for i in range(NUM_CLASSES):
        if class_total[i] > 0:
            LOG_STR += f"Test Accuracy of {i:5}: " \
                       f"{100 * class_correct[i] / class_total[i]:3.0f}% " \
                       f"({np.sum(class_correct[i]):4} / {np.sum(class_total[i]):4} )"
            LOG_STR += "\n"
        else:
            LOG_STR += f"Test Accuracy of {classes[i]}: N/A (no training examples)"
            LOG_STR += "\n"
    LOG_STR += f"\nTest Accuracy (Overall): {100. * np.sum(class_correct) / np.sum(class_total):3.0f}% " + \
               f"( {np.sum(class_correct)} / {np.sum(class_total)} )"

    if pid == 0:
        print(LOG_STR)
    with open(log_dir / f"stdout_{pid}", "w") as f:
        f.write(LOG_STR)

    done.wait()
    results["mem_after"] = get_process_memory()
    with open(results_log, 'w') as f:
        f.write(str(results))
    if pid == 0:
        with open(results_dir / f'latest_{pid}.txt', 'w') as f:
            f.write(str(results))
    return results