コード例 #1
0
def secure_model(model, graph_dir='/tmp', graph_fname='model.pb'):
    """Convert a plaintext Keras model into a ``PrivateModel``.

    The current Keras session graph is frozen (variables folded into
    constants, keeping only what ``model.outputs`` depend on), written as a
    binary GraphDef to ``os.path.join(graph_dir, graph_fname)``, read back
    with ``load_graph``, stripped by ``remove_training_nodes`` and passed
    through the ``tfe`` converter.

    Args:
        model: Keras model whose ``outputs`` name the output ops to keep.
        graph_dir: Directory the frozen GraphDef is written to. Defaults to
            ``'/tmp'``, the previously hard-coded location. NOTE: a fixed
            file in a shared directory can be clobbered by concurrent runs;
            pass a private directory when that matters.
        graph_fname: File name of the serialized GraphDef inside
            ``graph_dir``. Defaults to ``'model.pb'``.

    Returns:
        ``PrivateModel`` wrapping the converter's output.
    """
    import os  # local import: the file's top-level imports are not visible here

    session = K.get_session()
    min_graph = graph_util.convert_variables_to_constants(
        session, session.graph_def, [node.op.name for node in model.outputs])
    tf.train.write_graph(min_graph, graph_dir, graph_fname, as_text=False)

    # Build the read path from the same pieces used for writing, so the
    # write location and the read location can never drift apart.
    graph_def, inputs = load_graph(os.path.join(graph_dir, graph_fname))

    c = tfe.convert.convert.Converter()
    y = c.convert(remove_training_nodes(graph_def), tfe.convert.register(),
                  'input-provider', inputs)

    return PrivateModel(y)
コード例 #2
0
def secure_model(model, **converter_kwargs):
  """Turn the plaintext model held by the active Keras session into a PrivateModel.

  The session graph is frozen around ``model.outputs``, round-tripped through
  a binary ``model.pb`` file under ``_TMPDIR``, cleaned of training nodes and
  handed to the ``tfe`` converter.

  Args:
    model: Keras model whose output ops define the graph to freeze.
    **converter_kwargs: Forwarded unchanged to the ``Converter`` constructor.

  Returns:
    ``PrivateModel`` wrapping the converted graph output.
  """
  sess = K.get_session()
  # Read the live GraphDef before gathering output names, as the original
  # argument evaluation order did.
  live_graph = sess.graph_def
  output_names = [out.op.name for out in model.outputs]
  frozen = graph_util.convert_variables_to_constants(
      sess, live_graph, output_names)

  pb_name = 'model.pb'
  tf.train.write_graph(frozen, _TMPDIR, pb_name, as_text=False)
  pb_path = os.path.join(_TMPDIR, pb_name)
  loaded_def, graph_inputs = load_graph(pb_path)

  converter = tfe.convert.convert.Converter(
      tfe.convert.registry(), **converter_kwargs)
  cleaned_def = remove_training_nodes(loaded_def)
  private_out = converter.convert(cleaned_def, 'input-provider', graph_inputs)

  return PrivateModel(private_out)