Beispiel #1
0
    def minimize(self,
                 loss,
                 global_step=None,
                 var_list=None,
                 gate_gradients=1,
                 aggregation_method=None,
                 colocate_gradients_with_ops=False,
                 name=None,
                 grad_loss=None):
        """
        Run StaticReplacePass on `loss`, then build the training op for the result.

        An RttTensor loss is unwrapped to its `_raw` tensor before the pass runs.
        The rewritten loss goes to the `minimize` found after
        tf.train.GradientDescentOptimizer in the MRO. Message ids are then
        regenerated and notified for the rewritten loss.

        :param loss: the loss (a TensorFlow tensor or an RttTensor)
        :param global_step: forwarded unchanged to the base `minimize`
        :param var_list: forwarded unchanged to the base `minimize`
        :param gate_gradients: forwarded unchanged to the base `minimize`
        :param aggregation_method: forwarded unchanged to the base `minimize`
        :param colocate_gradients_with_ops: forwarded unchanged to the base `minimize`
        :param name: forwarded unchanged to the base `minimize`
        :param grad_loss: forwarded unchanged to the base `minimize`
        :return: the training op built by the base optimizer
        """
        rtt_get_logger().debug('begin to run StaticReplacePass...')
        replace_pass = StaticReplacePass()
        raw_loss = loss._raw if isinstance(loss, rtt_ts.RttTensor) else loss
        replaced_loss = replace_pass.run(raw_loss)
        rtt_get_logger().debug('end to run StaticReplacePass.')

        # Build the secure gradient graph with the base optimizer's minimize().
        base_optimizer = super(tf.train.GradientDescentOptimizer, self)
        train_op = base_optimizer.minimize(replaced_loss, global_step, var_list,
                                           gate_gradients, aggregation_method,
                                           colocate_gradients_with_ops, name,
                                           grad_loss)

        MsgIdGenerator(regenerate=True).gen_msgid_and_notified(replaced_loss)
        return train_op
Beispiel #2
0
    def is_exist_secure_op(self, op):
        """
        Check whether the data flow (sub)graph feeding `op` has one or more Secure nodes.

        The graph is walked from `op` towards its inputs. An operation counts as
        secure when its op-def name is contained in any value of
        `self.secure_ops_infos`.

        The walk is iterative and remembers visited operations. Subgraphs shared
        by several consumers are inspected only once, and deep graphs cannot hit
        Python's recursion limit.

        :param op: a Tensor or Operation that is part of the subgraph
        :return: True if a secure node exists, otherwise False. Also returns
                 False, after logging an error, when `op` is neither a Tensor
                 nor an Operation.
        """
        if isinstance(op, ops.Tensor):
            root_op = op.op
        elif isinstance(op, ops.Operation):
            root_op = op
        else:
            # If the op is not Tensor/Operation, return false.
            _errmsg = "Unkown the op {} at is_exist_secure_op() function.".format(str(op))
            rtt_get_logger().error(_errmsg)
            return False

        # Materialize once instead of once per visited operation.
        secure_info_val = list(self.secure_ops_infos.values())

        visited = set()
        pending = [root_op]
        while pending:
            cur_op = pending.pop()
            if cur_op in visited:
                continue
            visited.add(cur_op)

            op_def_name = cur_op.op_def.name
            if any(op_def_name in item for item in secure_info_val):
                return True

            # Inputs are tensors; continue the walk from the ops producing them.
            pending.extend(x.op for x in cur_op.inputs)

        return False
Beispiel #3
0
 def gen_msgid_and_notified(self, loss):
     """
     Generate the Rosetta message id for `loss` and notify it to the player.

     The id returned by `self._generate(loss)` is logged at debug level, then
     handed to a fresh `_rtt.msgid_handle.MsgIdHandle` through
     `update_message_id_info`.
     """
     generated_id = self._generate(loss)
     rtt_get_logger().debug(generated_id)
     _rtt.msgid_handle.MsgIdHandle().update_message_id_info(generated_id)
Beispiel #4
0
    def _create_secure_op_helper(self, src_op, secure_op_name, secure_op_creator_with_attr,
                                secure_op_creator_without_attr, secure_op_input_num, to_graph):
        """
        Create the secure counterpart of `src_op` in `to_graph`.

        The new op is named `<secure_op_name><src_op._id>`, prefixed with
        `<self.rtt_scope>/` when a scope is set. It is recorded in
        `self.dc_op_info` under the (scoped) name of the source op.

        :param src_op: source op instance, it's not secure op
        :param secure_op_name: secure op name, don't contain namespace
        :param secure_op_creator_with_attr: secure op creator with build-in attribute, eg: SecureMatmul's transpose_a and transpose_b attributes
        :param secure_op_creator_without_attr: secure op creator without build-in attribute, eg:SecureMul
        :param secure_op_input_num: secure op input numbers
        :param to_graph: dest graph
        :return: the secure op instance
        :raises NotImplementedError: if `src_op` has an original_op
        :raises AssertionError: if `src_op` has control inputs or a wrong number
                                of inputs (only when asserts are enabled)
        :raises Exception: whatever the secure op creation raised (logged first)
        """
        # Names of the new secure op and of the source op, inside rtt_scope if set.
        if self.rtt_scope != '':
            new_secure_name = self.rtt_scope + '/' + secure_op_name + str(src_op._id)
            new_src_name = self.rtt_scope + '/' + src_op.name
        else:
            new_secure_name = secure_op_name + str(src_op._id)
            new_src_name = src_op.name

        # The original_op must be None; copying it is not supported.
        if src_op._original_op is not None:
            raise NotImplementedError("not supported")

        # Copy control inputs recursively; a secure op must not have any.
        new_control_inputs = [self.copy_and_replace_to_graph(x, to_graph)
                              for x in src_op.control_inputs]
        assert len(new_control_inputs) == 0, "%s don't have control input edges" % secure_op_name

        # Copy data inputs recursively and check the arity.
        new_inputs = [self.copy_and_replace_to_graph(x, to_graph)
                      for x in src_op.inputs]
        assert len(new_inputs) == secure_op_input_num, "{0} need {1} edges, but real edges is {2}".format(
            secure_op_name, secure_op_input_num, len(new_inputs))

        # checked secure op inputs
        self._checked_secure_op_inputs(secure_op_name, new_inputs)

        try:
            if self._is_op_support_const_attr(src_op.op_def.name):
                create_fn = self._create_secure_op_with_const_attr
            else:
                create_fn = self._create_secure_op_without_const_attr
            new_op = create_fn(src_op, secure_op_creator_with_attr,
                               secure_op_creator_without_attr, new_inputs,
                               secure_op_input_num, new_secure_name)
        except Exception as e:
            rtt_get_logger().error(str(e))
            # Re-raise. Otherwise `new_op` is unbound below and the real cause
            # is masked by an UnboundLocalError.
            raise

        # Save the op info and return the op
        self.dc_op_info[new_src_name] = new_op
        return new_op
Beispiel #5
0
 def __get_numpy(self, file: str, is_x: bool, *args, **kwargs):
     """
     Load a CSV file with pandas and return its content as a numpy array.

     :param file: path of the CSV file
     :param is_x: not used by this method
     :param args: extra positional arguments forwarded to `pd.read_csv`
     :param kwargs: extra keyword arguments forwarded to `pd.read_csv`
     :return: the array, or None when the file cannot be read or parsed
              (the error is logged in that case)
     """
     try:
         return pd.read_csv(file, *args, **kwargs).to_numpy()
     except Exception as e:
         # Best effort: callers still get None, but the reason is no longer
         # hidden. FileNotFoundError is an Exception subclass and lands here too.
         rtt_get_logger().error(str(e))
         return None
Beispiel #6
0
    def _deepcopy_op_helper(self, src_op, new_name, new_inputs,
                            new_ctrl_inputs, new_original_op, to_graph):
        """
        Create a copy of `src_op` named `new_name` in `to_graph`.

        If an op is already recorded under `new_name` in `self.dc_op_info`, that
        op is returned and nothing is created. Otherwise a new `ops.Operation`
        is built from deep copies of the source NodeDef and OpDef. It is then
        recorded in `self.dc_op_info[new_name]` and returned.

        :param src_op: source op instance.
        :param new_name: new op name.
        :param new_inputs: new op inputs.
        :param new_ctrl_inputs: new op control inputs.
        :param new_original_op: new original op. Used to associate the new
            `Operation` with an existing `Operation` (for example, a replica
            with the op that was replicated).
        :param to_graph: dest graph
        :return: the deep copy op instance
        :raises Exception: whatever `ops.Operation` raised (logged first)
        """
        # Reuse an op that was already copied under this name.
        if new_name in self.dc_op_info:
            return self.dc_op_info[new_name]

        # The NodeDef stores per-instance, string-based info such as the name,
        # device and type of the op; copy it and rename the copy.
        new_node_def = deepcopy(src_op.node_def)
        new_node_def.name = new_name

        # Copy the other inputs needed for initialization.
        output_types = src_op._output_types[:]
        input_types = src_op._input_types[:]

        # The OpDef is shared by every op of the same type; copy it too.
        op_def = deepcopy(src_op.op_def)

        try:
            new_op = ops.Operation(new_node_def,
                                   to_graph,
                                   new_inputs,
                                   output_types,
                                   new_ctrl_inputs,
                                   input_types,
                                   new_original_op,
                                   op_def)
        except Exception as e:
            rtt_get_logger().error(str(e))
            # Re-raise rather than fall through with `new_op` unbound.
            raise

        # Save the op info
        self.dc_op_info[new_name] = new_op
        return new_op
Beispiel #7
0
    def _is_need_secure_op(self, op):
        """
        Decide, based on the data flow, whether `op` must be replaced by a secure op.

        Rules, applied in order to the (rtt_scope-prefixed) name and the object:
          * name in `self.all_vars`: True only if it is also in `self.train_vars`;
          * name in `self.need_secure_op_name`: True;
          * a Tensor is judged by the op that produces it;
          * an Operation of type "Placeholder": True;
          * any other Operation: True if any of its inputs needs a secure op.
        Anything that is neither a Tensor nor an Operation is logged and yields False.

        :param op: the Tensor or Operation to judge
        :return: True if a secure op is needed, otherwise False
        """
        if self.rtt_scope != '':
            new_name = self.rtt_scope + '/' + op.name
        else:
            new_name = op.name

        # Variables: only trainable ones need a secure op.
        if new_name in self.all_vars:
            return new_name in self.train_vars

        if new_name in self.need_secure_op_name:
            return True

        if isinstance(op, ops.Tensor):
            return self._is_need_secure_op(op.op)

        if isinstance(op, ops.Operation):
            if op.op_def.name == "Placeholder":
                return True

            # any() stops at the first input that needs a secure op, instead of
            # evaluating every input subgraph up front.
            return any(self._is_need_secure_op(x) for x in op.inputs)

        # If the op is not Tensor/Operation, return false.
        rtt_get_logger().error("Unkown the op: " + str(op))
        return False
Beispiel #8
0
def _convert_tensorflow_tensor(tensor):
    """
    Convert a native TensorFlow tensor into an RttTensor.

    String tensors are passed to `rtt_ops.tf_to_rtt` as-is. int32, int64,
    float16, float32 and float64 tensors are first converted to strings with
    `tf.as_string`.

    :raises ValueError: for any other dtype
    """
    numeric_dtypes = (tf.int32, tf.int64, tf.float16, tf.float32, tf.float64)

    if tensor.dtype in (tf.string, ):
        rtt_get_logger().debug("debug in _convert_tensorflow_tensor:" + str(tensor))
        return RttTensor(rtt_ops.tf_to_rtt(tensor))
    elif tensor.dtype in numeric_dtypes:
        as_text = tf.as_string(tensor)
        return RttTensor(rtt_ops.tf_to_rtt(as_text))
    else:
        message = "Don't know how to convert TensorFlow tensor with dtype '{}'".format(
            tensor.dtype)
        raise ValueError(message)
    def create_secure_op(self, src_op, to_graph):
        """
        Create a secure op based on the source op, e.g.: Mul op => SecureMul op

        The lower-cased op-def name of `src_op` selects an entry of
        `self.secure_ops_infos`. That entry supplies the secure def name, the
        creator function and the secure op passed on to the creator.

        :param src_op: source op instance.
        :param to_graph: dest graph
        :return: the secure op instance, or None (after logging an error) when
                 no secure op is registered for this op type
        :raises AssertionError: if the registered creator is None
        """
        op_key = src_op.op_def.name.lower()

        # Only the registry lookup belongs in the try. A KeyError raised by the
        # creator itself must not be misreported as "not implemented".
        try:
            secure_info = self.secure_ops_infos[op_key]
        except KeyError:
            _errmsg = "tf native op {} does not implemented Secure op".format(op_key)
            rtt_get_logger().error(_errmsg)
            return None

        secure_def_name = secure_info[self.SECURE_DEF_NAME_IDX]
        secure_create_func = secure_info[self.SECURE_CREATOR_IDX]
        secure_op = secure_info[self.SECURE_OP_IDX]
        assert secure_create_func is not None, "secure creator is none."
        return secure_create_func(src_op, secure_def_name, secure_op, to_graph)
Beispiel #10
0
    def create_secure_op(self, src_op, to_graph):
        """
        Create a secure op based on the source op, e.g.: Mul op => SecureMul op

        The lower-cased op-def name of `src_op` selects an entry of
        `self.secure_ops_infos`. That entry supplies the secure def name, the
        creators with and without attributes, and the number of inputs. The
        actual construction is delegated to `_create_secure_op_helper`.

        :param src_op: source op instance.
        :param to_graph: dest graph
        :return: the secure op instance, or None (after logging an error) when
                 no secure op is registered for this op type
        """
        op_key = src_op.op_def.name.lower()

        # Only the registry lookup belongs in the try. A KeyError raised while
        # building the op must not be misreported as "not implemented".
        try:
            secure_info = self.secure_ops_infos[op_key]
        except KeyError:
            _errmsg = "tf native op {} does not implemented Secure op".format(op_key)
            rtt_get_logger().error(_errmsg)
            return None

        secure_def_name = secure_info[self.SECURE_DEF_NAME_IDX]
        secure_op_creator_with_attr = secure_info[self.SECURE_OP_CREATOR_WITH_ATTR_IDX]
        secure_op_creator_without_attr = secure_info[self.SECURE_OP_CREATOR_WITHOUT_ATTR_IDX]
        secure_op_inputs = secure_info[self.SECURE_OP_INPUTS_IDX]

        return self._create_secure_op_helper(src_op, secure_def_name, secure_op_creator_with_attr,
                                             secure_op_creator_without_attr, secure_op_inputs, to_graph)
Beispiel #11
0
    def save_op(self, filename_tensor, saveables):
        """Create an Op to save 'saveables'.

        This is intended to be overridden by subclasses that want to generate
        different Ops.

        Args:
          filename_tensor: String Tensor.
          saveables: A list of BaseSaverBuilder.SaveableObject objects.

        Returns:
          An Operation that saves the variables. For the V1 format it is built
          by `io_ops._save`; for the V2 format by `SecureSaveV2`.

        Raises:
          RuntimeError: (implementation detail) if "self._write_version" is an
            unexpected value.
        """
        # This log call must come after the docstring. A statement placed
        # before it turns the docstring into a discarded string expression.
        rtt_get_logger().debug("DEBUG: save_op")
        # pylint: disable=protected-access
        tensor_names = []
        tensors = []
        tensor_slices = []
        for saveable in saveables:
            for spec in saveable.specs:
                tensor_names.append(spec.name)
                tensors.append(spec.tensor)
                tensor_slices.append(spec.slice_spec)
        if self._write_version == saver_pb2.SaverDef.V1:
            return io_ops._save(
                filename=filename_tensor,
                tensor_names=tensor_names,
                tensors=tensors,
                tensor_slices=tensor_slices)
        elif self._write_version == saver_pb2.SaverDef.V2:
            # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
            # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
            return SecureSaveV2(filename_tensor, tensor_names, tensor_slices,
                                tensors)
        else:
            # str() is needed: a non-string value (e.g. an integer SaverDef
            # enum) would otherwise raise TypeError instead of RuntimeError.
            raise RuntimeError("Unexpected write_version: " + str(self._write_version))
Beispiel #12
0
        self.stream.writelines(datas)
        self.stream.flush()
    def __getattr__(self, attr):
        """Delegate attribute lookups not found on the wrapper to `self.stream`."""
        wrapped = self.stream
        return getattr(wrapped, attr)
sys.stdout = ZeroBufferOut(sys.stdout)


# Path of the dynamic-pass TensorFlow op library. ROSETTA_MPC_128 does not
# change this path: the same library is loaded in 64-bit and 128-bit mode.
# NOTE(review): if a 128-bit build ('/../lib128/libtf-dpass.so') is meant to be
# used when ROSETTA_MPC_128=ON, select it here -- confirm against the install layout.
tf_dpass_lib = os.path.dirname(__file__) + '/../libtf-dpass.so'

# ROSETTA_DPASS=OFF skips loading the library; `dpass` then stays None.
dpass = None
if os.environ.get('ROSETTA_DPASS') == 'OFF':
    rtt_get_logger().debug('NOT load library: {}, disable dynamic pass.'.format(tf_dpass_lib))
else:
    dpass = tf.load_op_library(tf_dpass_lib)
    rtt_get_logger().debug('load library: {}'.format(tf_dpass_lib))

__doc__ = """
    This is LatticeX Rosetta.
"""

# RTT
from latticex.rosetta.rtt import *
# MPC
from latticex.rosetta.secure import *
# ZK for the future
# HE for the future
Beispiel #13
0
def __check(task_id):
    """
    Ensure the protocol for `task_id` has been activated.

    :raises Exception: if `py_protocol_handler.is_activated(task_id)` is falsy
                       (the message is also logged as an error)
    """
    if py_protocol_handler.is_activated(task_id):
        return
    errmsg = "Protocol have not activated. See rtt.activate()."
    rtt_get_logger().error(errmsg)
    raise Exception(errmsg)
Beispiel #14
0
    def _create_secure_op_helper(self, src_op, secure_op_name,
                                 secure_op_input_num, secure_op_creator,
                                 to_graph):
        """
        Create the secure counterpart of `src_op` in `to_graph`.

        The new op is named `secure_op_name`, prefixed with `<self.rtt_scope>/`
        when a scope is set. It is recorded in `self.dc_op_info` under the
        (scoped) name of the source op.

        :param src_op: source op instance, it's not secure op
        :param secure_op_name: secure op name, don't contain namespace
        :param secure_op_input_num: secure op input numbers (1 to 6 are supported)
        :param secure_op_creator: secure op creator
        :param to_graph: dest graph
        :return: the secure op instance
        :raises NotImplementedError: if `src_op` has an original_op
        :raises AssertionError: if `src_op` has control inputs or a wrong number
                                of inputs (only when asserts are enabled)
        :raises ValueError: if `secure_op_input_num` is not supported (logged first)
        :raises Exception: whatever the secure op creator raised (logged first)
        """
        # Names of the new secure op and of the source op, inside rtt_scope if set.
        if self.rtt_scope != '':
            new_secure_name = self.rtt_scope + '/' + secure_op_name
            new_src_name = self.rtt_scope + '/' + src_op.name
        else:
            new_secure_name = secure_op_name
            new_src_name = src_op.name

        # The original_op must be None; copying it is not supported.
        if src_op._original_op is not None:
            raise NotImplementedError("not supported")

        # Copy control inputs recursively; a secure op must not have any.
        new_control_inputs = [
            self.copy_and_replace_to_graph(x, to_graph)
            for x in src_op.control_inputs
        ]
        assert len(new_control_inputs) == 0, \
            "%s don't have control input edges" % secure_op_name

        # Copy data inputs recursively and check the arity.
        new_inputs = [
            self.copy_and_replace_to_graph(x, to_graph) for x in src_op.inputs
        ]
        assert len(new_inputs) == secure_op_input_num, \
            "{0} need {1} edges, but real edges is {2}".format(
                secure_op_name, secure_op_input_num, len(new_inputs))

        # checked secure op inputs
        self._checked_secure_op_inputs(secure_op_name, new_inputs)

        try:
            if secure_op_input_num == 2:
                if self._is_bin_op_unsupport_const_attr(src_op.op_def.name):
                    new_op = self._create_unsupport_const_attr_secure_bin_op(
                        src_op, secure_op_creator, new_inputs, new_secure_name)
                else:
                    # Binary creators additionally receive lh_is_const/rh_is_const,
                    # computed per input by _secure_input_is_const.
                    new_op = secure_op_creator(
                        new_inputs[0],
                        new_inputs[1],
                        name=new_secure_name,
                        lh_is_const=self._secure_input_is_const(new_inputs[0]),
                        rh_is_const=self._secure_input_is_const(new_inputs[1])).op
            elif secure_op_input_num in (1, 3, 4, 5, 6):
                # The remaining creators take their inputs positionally, in order.
                new_op = secure_op_creator(*new_inputs[:secure_op_input_num],
                                           name=new_secure_name).op
            else:
                raise ValueError("the %s op don't have %d inputs" %
                                 (secure_op_name, secure_op_input_num))
        except Exception as e:
            rtt_get_logger().error(str(e))
            # Re-raise. Otherwise `new_op` is unbound below and the real cause
            # is masked by an UnboundLocalError.
            raise

        # Save the op info and return it
        self.dc_op_info[new_src_name] = new_op
        return new_op
Beispiel #15
0
    def _is_need_secure_op(self, op):
        """
        Decide, based on the data flow, whether `op` must be replaced by a secure op.

        Rules, applied in order to the (rtt_scope-prefixed) name and the object:
          * name in `self.all_vars`: True only if it is also in `self.train_vars`;
          * name in `self.need_secure_op_name`: True;
          * a Tensor is judged by the op that produces it;
          * an Operation of type "Placeholder": True;
          * an "IteratorV2" Operation: judged by the other input of the first
            "MakeIterator" op in the default graph that consumes it, if any;
          * an Operation of type "BatchDatasetV2": True;
          * any other Operation: True if any of its inputs needs a secure op.
        Anything that is neither a Tensor nor an Operation is logged and yields False.

        :param op: the Tensor or Operation to judge
        :return: True if a secure op is needed, otherwise False
        """
        if self.rtt_scope != '':
            new_name = self.rtt_scope + '/' + op.name
        else:
            new_name = op.name

        # Variables: only trainable ones need a secure op.
        if new_name in self.all_vars:
            return new_name in self.train_vars

        if new_name in self.need_secure_op_name:
            return True

        if isinstance(op, ops.Tensor):
            return self._is_need_secure_op(op.op)

        if isinstance(op, ops.Operation):
            op_type = op.op_def.name

            if op_type == "Placeholder":
                return True

            # For an iterator, analyse the dataset side of the MakeIterator
            # op (the subgraph root) that is wired to this iterator.
            if op_type == "IteratorV2":
                for unit_op in tf.get_default_graph().get_operations():
                    if unit_op.op_def.name == "MakeIterator":
                        assert len(unit_op.inputs) == 2, "MakeIterator op inputs is incorrect."
                        if unit_op.inputs[0].op == op:
                            return self._is_need_secure_op(unit_op.inputs[1])
                        elif unit_op.inputs[1].op == op:
                            return self._is_need_secure_op(unit_op.inputs[0])

            if op_type == "BatchDatasetV2":
                return True

            # any() stops at the first input that needs a secure op, instead of
            # evaluating every input subgraph up front.
            return any(self._is_need_secure_op(x) for x in op.inputs)

        # If the op is not Tensor/Operation, return false.
        rtt_get_logger().error("Unkown the op: " + str(op))
        return False
Beispiel #16
0
from latticex.rosetta.controller.controller_base_ import _rtt
from latticex.rosetta.controller.common_util import rtt_get_logger

# A simple way to cope commandline arguments
# ########################### commandline arguments
import argparse
import os

_parser = argparse.ArgumentParser(description="LatticeX Rosetta")
_parser.add_argument('--party_id',
                     help="Party ID",
                     type=int,
                     choices=[0, 1, 2],
                     default=-1,
                     required=False)
_parser.add_argument('--cfgfile',
                     help="Config File",
                     type=str,
                     default=os.path.abspath('.') + "/CONFIG.json")
# Unknown command-line arguments are not an error here; parse_known_args
# returns them separately in `_unparsed`.
_args, _unparsed = _parser.parse_known_args()

# The party id can also be configured in CONFIG.json; per the project's
# convention, a value given on the command line overrides the config file.
_party_id, _cfgfile = _args.party_id, _args.cfgfile

_startup_msg = 'party id: {} with config json file: {}'.format(_party_id, _cfgfile)
rtt_get_logger().info(_startup_msg)

py_protocol_handler = _rtt.protocol.ProtocolHandler()