コード例 #1
0
ファイル: register.py プロジェクト: zhaoyang626/tf-encrypted
def _add(converter, node: Any, inputs: List[str]) -> Any:
  """Convert an Add node into a ``tfe.add`` of its first two inputs.

  Both operands are looked up in ``converter.outputs`` by input name. Any
  operand that is still a ``tf.NodeDef`` is first converted with
  ``_nodef_to_public_pond``; every other operand is used as-is.

  Arguments:
    converter: object exposing an ``outputs`` mapping from input name to
        the already-converted value for that name.
    node: the node being converted (not used by this function).
    inputs: names of the node's inputs; the first two are the operands.

  Returns:
    The result of ``tfe.add`` on the two operands.
  """
  def _lift(operand):
    # Only raw NodeDefs need conversion; anything else passes through.
    if not isinstance(operand, tf.NodeDef):
      return operand
    return _nodef_to_public_pond(converter, operand)

  # Resolve both operands before converting either one, so lookups happen
  # in the same order as the conversions that follow.
  lhs_raw = converter.outputs[inputs[0]]
  rhs_raw = converter.outputs[inputs[1]]

  lhs = _lift(lhs_raw)
  rhs = _lift(rhs_raw)
  return tfe.add(lhs, rhs)
コード例 #2
0
    def set_weights(self, weights, sess=None):
        """Update the layer's weights and recompute its ``denom`` tensor.

        Arguments:
          weights: A list of NumPy arrays whose shapes match the output of
              ``layer.get_weights()``, or a list of ``PondPublicTensor``
              objects. Entries are matched to ``self.weights`` by position
              and reshaped to the corresponding weight's shape.
          sess: tfe session. If ``None``, ``KE.get_session()`` is used.

        Raises:
          TypeError: if ``weights`` holds NumPy arrays but one of the
              layer's weights is not a ``PondPublicTensor``, or if the
              entries of ``weights`` are of any other type.
        """
        if sess is None:
            sess = KE.get_session()

        if isinstance(weights[0], np.ndarray):
            # Validate every target before assigning anything, so a bad
            # weight cannot leave the layer partially updated.
            for w in self.weights:
                if not isinstance(w, PondPublicTensor):
                    raise TypeError(
                        (
                            "Don't know how to handle weights "
                            "of type {}. Batchnorm expects public tensors "
                            "as weights"
                        ).format(type(w))
                    )

            for i, w in enumerate(self.weights):
                shape = w.shape.as_list()
                # NumPy values are fed into the graph through a public
                # placeholder, then assigned to the layer's tensor.
                tfe_weights_pl = tfe.define_public_placeholder(shape)
                fd = tfe_weights_pl.feed(weights[i].reshape(shape))
                sess.run(tfe.assign(w, tfe_weights_pl), feed_dict=fd)

        elif isinstance(weights[0], PondPublicTensor):
            for i, w in enumerate(self.weights):
                shape = w.shape.as_list()
                sess.run(tfe.assign(w, weights[i].reshape(shape)))

        else:
            # Previously this case fell through silently: no weight was
            # assigned, yet denom was still recomputed below.
            raise TypeError(
                (
                    "Don't know how to handle weights of type {}. "
                    "Expected NumPy arrays or public tensors"
                ).format(type(weights[0]))
            )

        # Compute denom = 1 / sqrt(moving_variance + epsilon) on public
        # tensors before being lifted to private tensor
        denomtemp = tfe.reciprocal(
            tfe.sqrt(tfe.add(self.moving_variance, self.epsilon))
        )

        # Update denom as well when moving variance gets updated
        sess.run(tfe.assign(self.denom, denomtemp))