Example #1
    def run(self, loss):
        """
        Run the static replace pass: depending on each op's data flow, the op is
        either deep-copied into the new graph or replaced with its secure counterpart.

        :param loss: A `Tensor` containing the value to minimize.
        :return: the new loss tensor
        """

        # Check whether this loss has already been replaced with a secure graph.
        # If so, return the cached secure loss; otherwise we need to run
        # copy_and_replace_to_graph() to build it.
        if loss in StaticReplacePass.tf_graph_mapto_secure_graph:
            secure_loss = StaticReplacePass.tf_graph_mapto_secure_graph[loss]
            MsgIdGenerator().gen_msgid_and_notified(secure_loss)
            return secure_loss

        # Get default graph
        to_graph = tf.get_default_graph()

        # Deep copy and replace source op
        secure_loss = self.copy_and_replace_to_graph(loss, to_graph)

        # Save the secure loss
        StaticReplacePass.tf_graph_mapto_secure_graph[loss] = secure_loss

        # Generate the message ids and notify the players
        MsgIdGenerator().gen_msgid_and_notified(secure_loss)

        return secure_loss
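
The run() method above is essentially a memoized graph rewrite: the class-level dict tf_graph_mapto_secure_graph caches the secure loss per original loss tensor, so the expensive copy-and-replace runs only once per graph. Below is a minimal, self-contained sketch of that caching pattern; the class and helper names (CachedReplacePass, _copy_and_replace, _notify) are hypothetical stand-ins for StaticReplacePass, copy_and_replace_to_graph(), and MsgIdGenerator, not the Rosetta API.

class CachedReplacePass:
    # Class-level cache: original node -> transformed ("secure") node,
    # mirroring StaticReplacePass.tf_graph_mapto_secure_graph.
    _cache = {}

    def run(self, node):
        # Cache hit: the node was already transformed, reuse the result.
        if node in CachedReplacePass._cache:
            secure = CachedReplacePass._cache[node]
            self._notify(secure)  # stand-in for MsgIdGenerator().gen_msgid_and_notified()
            return secure

        # Cache miss: perform the expensive copy-and-replace once and cache it.
        secure = self._copy_and_replace(node)
        CachedReplacePass._cache[node] = secure
        self._notify(secure)
        return secure

    def _copy_and_replace(self, node):
        # Placeholder for copy_and_replace_to_graph(loss, to_graph).
        return ("secure", node)

    def _notify(self, secure):
        # Placeholder for generating message ids and notifying the players.
        pass


pass_obj = CachedReplacePass()
a = pass_obj.run("loss_tensor")
b = pass_obj.run("loss_tensor")  # second call returns the cached secure node
assert a is b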
Example #2
    def minimize(self,
                 loss,
                 global_step=None,
                 var_list=None,
                 gate_gradients=1,
                 aggregation_method=None,
                 colocate_gradients_with_ops=False,
                 name=None,
                 grad_loss=None):
        rtt_get_logger().debug('begin to run StaticReplacePass...')

        # Create StaticReplacePass
        PassObj = StaticReplacePass()

        # Run the pass on the loss tensor; it returns the new secure loss
        if isinstance(loss, rtt_ts.RttTensor):
            loss = PassObj.run(loss._raw)
        else:
            loss = PassObj.run(loss)
        rtt_get_logger().debug('end to run StaticReplacePass.')

        # Generate the secure gradient graph by delegating to the base optimizer's minimize()
        train_op = super(tf.train.GradientDescentOptimizer,
                         self).minimize(loss, global_step, var_list,
                                        gate_gradients, aggregation_method,
                                        colocate_gradients_with_ops, name,
                                        grad_loss)

        # Regenerate the message ids and notify the players
        MsgIdGenerator(regenerate=True).gen_msgid_and_notified(loss)

        return train_op
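
This minimize() override follows a pre-/post-processing pattern: rewrite the loss into its secure form, delegate gradient construction to the stock optimizer via super(), then regenerate the message ids. Below is a framework-free sketch of that delegation pattern, with hypothetical class names (BaseOptimizer, SecureOptimizer) standing in for tf.train.Optimizer and the Rosetta optimizer wrapper; it is an illustration of the structure, not the actual TensorFlow or Rosetta API.

class BaseOptimizer:
    def minimize(self, loss):
        # Stand-in for the stock optimizer's minimize(): build the training op.
        return ("train_op", loss)


class SecureOptimizer(BaseOptimizer):
    def minimize(self, loss):
        # 1. Rewrite the loss graph into its secure counterpart
        #    (stand-in for StaticReplacePass().run(loss)).
        secure_loss = ("secure", loss)

        # 2. Delegate gradient construction to the parent's minimize(),
        #    analogous to how the original calls
        #    super(tf.train.GradientDescentOptimizer, self).minimize(...)
        #    to reach the stock TensorFlow implementation.
        train_op = super().minimize(secure_loss)

        # 3. Post-process, e.g. regenerate message ids for the secure graph
        #    (stand-in for MsgIdGenerator(regenerate=True).gen_msgid_and_notified(loss)).
        return train_op


opt = SecureOptimizer()
print(opt.minimize("loss_tensor"))  # ('train_op', ('secure', 'loss_tensor'))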