def secure_mean(collected_inputs):
    """Average each group of secret-shared inputs and reveal the results.

    Arguments:
      collected_inputs: an iterable of groups; every group is a sequence of
        tfe private tensors that are summed and divided by the group size.

    Returns:
      A list with one native ``tf.float32`` tensor per group, holding the
      revealed mean of that group.
    """
    with tf.name_scope('secure_mean'):
        # Phase 1: average every group while the values are still private.
        private_means = [
            tfe.add_n(group) / len(group) for group in collected_inputs
        ]

        # Phase 2: reveal each mean and convert it to a plain TF float32 tensor.
        revealed_means = []
        for private_mean in private_means:
            native_mean = private_mean.reveal().to_native()
            revealed_means.append(tf.cast(native_mean, tf.float32))

        return revealed_means
def fit(self, training_players, summary=0, validation_split=None):
    """Trains the linear regressor.

    Each player computes its partial estimators locally. The partial values
    of every component in ``self.components`` are summed with secure
    computation, the sums are revealed, and the normal equations are then
    solved in plaintext TensorFlow.

    Arguments:
      training_players: Data owners used for joint training. Must implement
        the compute_estimators as a tfe.local_computation.
      summary: Controls what kind of summary statistics are generated after
        the linear regression fit. A falsy value skips the summary.
      validation_split: Mimics the behavior of the Keras validation_split
        kwarg. Not supported yet; must be None.

    Returns:
      ``self`` when ``summary`` is falsy, otherwise the result of
      ``self.summarize(summary_level=summary)``.

    Raises:
      NotImplementedError: if ``validation_split`` is not None.
    """
    if validation_split is not None:
        raise NotImplementedError("validation_split is not supported yet")

    # One tuple of partial estimators per player; zip(*...) below regroups
    # them per component, matching the order of self.components.
    partial_estimators = [
        player.compute_estimators(self.estimator_fn)
        for player in training_players
    ]

    # Securely sum each component across all players.
    for attr, partial_estimator in zip(self.components,
                                       zip(*partial_estimators)):
        setattr(self, attr, tfe.add_n(partial_estimator))

    # Reveal every aggregated component in a single run. Running them one at
    # a time would re-execute the players' local computations and the secure
    # protocol once per component. That is wasteful, and it yields mismatched
    # statistics whenever the players' inputs are stateful (e.g. fed by a
    # dataset iterator).
    components = list(self.components)
    with tfe.Session() as sess:
        revealed = sess.run([getattr(self, k).reveal() for k in components])
    for k, value in zip(components, revealed):
        setattr(self, k, value)

    # Solve the normal equations in plaintext on a fresh graph; the inverse is
    # kept on the instance alongside the coefficients.
    tf_graph = tf.Graph()
    with tf_graph.as_default():
        inverted = tf.linalg.inv(self.covariate_square)
        coefficients = tf.matmul(inverted, self.covariate_label_product)
    with tf.Session(graph=tf_graph) as sess:
        self._inverted_covariate_square, self.coefficients = sess.run(
            [inverted, coefficients])

    if not summary:
        return self
    return self.summarize(summary_level=summary)
def provide_input() -> tf.Tensor:
    """Return a random tensor of shape (10,) to be averaged."""
    return tf.random_normal(shape=(10, ))


if __name__ == '__main__':
    NUM_INPUTTERS = 5

    # get input from inputters 'inputter-0' .. 'inputter-4' as private values
    inputs = [
        tfe.define_private_input('inputter-{}'.format(i), provide_input)
        for i in range(NUM_INPUTTERS)
    ]

    # sum all inputs and divide by count
    result = tfe.add_n(inputs) / len(inputs)

    def receive_output(average: tf.Tensor) -> tf.Operation:
        """Print the average once it has been delivered to the receiver."""
        return tf.print("Average:", average)

    # send result to receiver
    result_op = tfe.define_output('result-receiver', result, receive_output)

    # run the averaging computation once, tagged 'average'
    with tfe.Session() as sess:
        sess.run(result_op, tag='average')
DataOwner("data-owner-0", "./data/train.tfrecord", model_owner.build_update_step), DataOwner("data-owner-1", "./data/train.tfrecord", model_owner.build_update_step), DataOwner("data-owner-2", "./data/train.tfrecord", model_owner.build_update_step), ] model_grads = zip(*( data_owner.compute_gradient() for data_owner in data_owners )) with tf.name_scope('secure_aggregation'): aggregated_model_grads = [ tfe.add_n(grads) / len(grads) for grads in model_grads ] iteration_op = model_owner.update_model(*aggregated_model_grads) with tfe.Session(target=session_target) as sess: sess.run(tf.global_variables_initializer(), tag='init') for i in range(model_owner.ITERATIONS): if i % 100 == 0: print("Iteration {}".format(i)) sess.run(iteration_op, tag='iteration') else: sess.run(iteration_op)
# Three data owners, all reading the same TFRecord path and each constructed
# with the model owner's build_update_step.
data_owners = [
    DataOwner(owner_name, "./data/train.tfrecord", model_owner.build_update_step)
    for owner_name in ("data-owner-0", "data-owner-1", "data-owner-2")
]

# Every owner's gradient output becomes a private input of the secure
# computation; zip(*...) regroups them so each tuple gathers the matching
# entry from every owner (presumably one entry per model variable).
private_grads = [
    tfe.define_private_input(owner.player_name, owner.compute_gradient)
    for owner in data_owners
]
model_grads = zip(*private_grads)

with tf.name_scope('secure_aggregation'):
    # Average each gradient across owners while it is still secret-shared.
    aggregated_model_grads = []
    for per_owner_grads in model_grads:
        aggregated_model_grads.append(
            tfe.add_n(per_owner_grads) / len(per_owner_grads))

# The averaged gradients are delivered to the model owner, whose update_model
# turns them into the per-iteration op.
iteration_op = tfe.define_output(
    model_owner.player_name, aggregated_model_grads, model_owner.update_model)

with tfe.Session(target=session_target) as sess:
    sess.run(tf.global_variables_initializer(), tag='init')

    for step in range(model_owner.ITERATIONS):
        if step % 100:
            sess.run(iteration_op)
            continue
        # Every 100th step: report progress and run with the 'iteration' tag.
        print("Iteration {}".format(step))
        sess.run(iteration_op, tag='iteration')