コード例 #1
0
    def _clone_model(self, model, perturbations, dst_scope):
        """Copy `model` under `dst_scope` and rewire the copy's inputs.

        Placeholders and every op whose name starts with an entry of
        ``self.tvars`` are excluded from the copy. Each input tensor of the
        resulting sub-graph view is then fed either from `perturbations`
        (when its tensor name is a key there) or from the tensor of the same
        name in ``self.work_graph``.

        Args:
            model: object exposing ``ops``, the operations to clone.
            perturbations: mapping from tensor name to the replacement tensor.
            dst_scope: destination scope handed to
                ``ge.copy_with_input_replacements``.

        Returns:
            The result of ``ge.copy_with_input_replacements`` for the copy.
        """
        def not_placeholder_or_trainvar_filter(op):
            # Evaluation sub-graphs are fed from the original placeholders.
            if op.type == 'Placeholder':
                return False
            # Drop Some/Var/(read,assign,...) ops -- perturbations replace them.
            return not any(op.name.startswith(var_name)
                           for var_name in self.tvars)

        ops_without_inputs = ge.filter_ops(model.ops, not_placeholder_or_trainvar_filter)

        # Best effort: keep the "init" op out of the clone if it is present.
        # KeyError   -> the work graph has no operation named "init".
        # ValueError -> the op exists but is not among the selected ops
        #               (assumes ge.filter_ops returns a list -- TODO confirm).
        try:
            ops_without_inputs.remove(self.work_graph.get_operation_by_name("init"))
        except (KeyError, ValueError):
            pass

        clone_sgv = ge.make_view(ops_without_inputs)
        clone_sgv = clone_sgv.remove_unused_ops(control_inputs=True)

        input_replacements = {}
        for t in clone_sgv.inputs:
            if t.name in perturbations:
                # Input comes from a trainable var: use its perturbation.
                input_replacements[t] = perturbations[t.name]
            else:
                # Otherwise keep reading from the original graph.
                input_replacements[t] = self.work_graph.get_tensor_by_name(t.name)
        return ge.copy_with_input_replacements(clone_sgv, input_replacements, dst_scope=dst_scope)
コード例 #2
0
 def test_copy_assert(self):
   # Build a tiny graph: two constants, an equality check, and an Assert on it.
   tf.reset_default_graph()
   lhs = tf.constant(1)
   rhs = tf.constant(1)
   are_equal = tf.equal(lhs, rhs)
   check = tf.Assert(are_equal, [lhs, rhs])
   # Give the Assert a consumer by making an add depend on it.
   with tf.control_dependencies([check]):
     tf.add(lhs, rhs)
   # Copy the view holding the Assert and its producers into the same graph.
   view = ge.make_view([check, are_equal.op, lhs.op, rhs.op])
   transformer = ge.Transformer()
   _, transform_info = transformer(view, view.graph, "", "")
   # The Assert op must have a transformed counterpart.
   self.assertIsNotNone(transform_info.transformed(check))
コード例 #3
0
 def test_copy_assert(self):
     # Build a tiny graph: two constants, an equality check, and an Assert.
     ops.reset_default_graph()
     first = constant_op.constant(1)
     second = constant_op.constant(1)
     same = math_ops.equal(first, second)
     guard = control_flow_ops.Assert(same, [first, second])
     # Give the Assert a consumer by making an add depend on it.
     with ops.control_dependencies([guard]):
         math_ops.add(first, second)
     # Copy the view holding the Assert and its producers into the same graph.
     source_view = ge.make_view([guard, same.op, first.op, second.op])
     transformer = ge.Transformer()
     _, transform_info = transformer(source_view, source_view.graph, "", "")
     # The Assert op must have a transformed counterpart.
     self.assertIsNotNone(transform_info.transformed(guard))
コード例 #4
0
ファイル: clone_graph.py プロジェクト: teobaluta/provero
def clone_subgraph(outputs, mappings, clone_scope=''):
    """Copy the ops that `outputs` depends on, substituting mapped inputs.

    The backward walk from `outputs` stops at the tensors that are keys of
    `mappings`. Ops whose type is one of the variable / hash-table / resource
    handle types listed below are excluded from the copy.

    Args:
        outputs: seed for the backward walk; also the object whose
            transformed counterpart is returned.
        mappings: mapping whose keys bound the walk and which is passed to
            ``ge.copy_with_input_replacements`` as the input replacements.
        clone_scope: destination scope for the copied ops.

    Returns:
        ``info.transformed(outputs)`` from the copy's transform info.
    """
    # Op types that are never copied.
    skipped_types = {
        'Variable',
        'VariableV2',
        'AutoReloadVariable',
        'MutableHashTable',
        'MutableHashTableV2',
        'MutableHashTableOfTensors',
        'MutableHashTableOfTensorsV2',
        'MutableDenseHashTable',
        'MutableDenseHashTableV2',
        'VarHandleOp',
        'BoostedTreesEnsembleResourceHandleOp',
    }

    walked_ops = ge.get_backward_walk_ops(outputs, stop_at_ts=mappings.keys())
    copyable_ops = []
    for candidate in walked_ops:
        if candidate.type in skipped_types:
            continue
        copyable_ops.append(candidate)

    view = ge.make_view(*copyable_ops)
    _copied_view, transform_info = ge.copy_with_input_replacements(
        view, mappings, dst_scope=clone_scope)
    return transform_info.transformed(outputs)
コード例 #5
0
    def _duplicate_graph(self, graph, vars_to_replace, name='Duplicated'):
        """
        Duplicates the ops producing `graph`, with swapped inputs.

        The walk starts at `graph` and follows op inputs backwards. It does
        not descend through 'VariableV2' or 'Placeholder' ops, nor through
        tensors that are keys of `vars_to_replace`. Each op is collected once,
        in first-visit (depth-first, input-order) order.

        :param graph: Tensor whose producing sub-graph is copied.
        :param vars_to_replace: Mapping from original tensors to replacements.
        :param name: Name scope under which the copy is made.
        :return: `vars_to_replace[graph]` when `graph` is itself a key;
            otherwise the copied view's output at the position `graph` has
            among the original view's outputs.
        """
        if graph in vars_to_replace:
            return vars_to_replace[graph]

        skipped_types = ('VariableV2', 'Placeholder')
        operations = []
        seen_ops = set()

        # An explicit stack avoids hitting the recursion limit on deep graphs;
        # the visited set keeps ops shared between branches from being walked
        # (and appended) once per path that reaches them.
        pending = [graph]
        while pending:
            op = pending.pop().op
            if op in seen_ops or op.type in skipped_types:
                continue
            seen_ops.add(op)
            operations.append(op)
            # Push in reverse so the first input is explored first.
            pending.extend(t for t in reversed(list(op.inputs))
                           if t not in vars_to_replace)

        sgv = graph_editor.make_view(operations)
        with ops.name_scope(name):
            new_view, _ = graph_editor.copy_with_input_replacements(
                sgv, vars_to_replace)
            return new_view.outputs[sgv.output_index(graph)]
コード例 #6
0
    def __init__(self, meta_graph_def, import_path, internal_scope):
        """Import `meta_graph_def` and index the names it exports.

        Node names in ``meta_graph_def.graph_def`` are classified as follows:

        * ``<internal_scope>/<name>`` is recorded as a plain export and later
          resolved to the imported operation.
        * ``<internal_scope>/<name>/inputs/...`` and
          ``<internal_scope>/<name>/outputs/...`` are grouped per ``<name>``
          and exported as a ``SubGraphViewFunction``.
        * Anything else is ignored.

        Args:
            meta_graph_def: meta graph to scan and to import into the default
                graph.
            import_path: scope the meta graph is imported under.
            internal_scope: name prefix of exported nodes; `import_path` is
                used when this is falsy.

        Raises:
            KeyError: re-raised from the import, after the default graph's
                node names have been logged.
        """
        eprint("import_path, internal_scope", import_path, internal_scope)

        internal_scope = internal_scope or import_path

        self._meta_graph_def = meta_graph_def
        self._internal_scope = internal_scope

        import_scope = import_path
        tensors = []
        functions = {}  # name -> (scope, inputs, outputs)
        # The scope is matched literally (the prefix-length arithmetic below
        # already treats it as a literal), so escape regex metacharacters.
        pattern = re.compile("%s/([^/]+)(?:$|(/inputs/|/outputs/).)" %
                             re.escape(internal_scope))
        graph_def = meta_graph_def.graph_def
        for n in graph_def.node:
            m = pattern.match(n.name)
            if not m:
                continue

            name = m.group(1)
            function_component = m.group(2)
            if not function_component:
                tensors.append(n.name)
                continue

            function = functions.setdefault(name, (name, [], []))

            if function_component.startswith("/inputs/"):
                # Inputs keep their full node name.
                function[1].append(n.name)
            else:
                # Outputs are stored relative to "<scope>/<name>/outputs/";
                # the +1 accounts for the "/" between scope and name.
                output_prefix_len = len(internal_scope) + len(name) + len(
                    function_component) + 1
                function[2].append(n.name[output_prefix_len:])

        try:
            with tf.name_scope(None):
                tf.train.import_meta_graph(meta_graph_def,
                                           import_scope=import_path)
        except KeyError:
            nodes = sorted(
                n.name for n in tf.get_default_graph().as_graph_def().node)
            eprint('error, but got nodes', nodes)
            raise

        nodes = sorted(
            n.name for n in tf.get_default_graph().as_graph_def().node)
        eprint('no error, but got nodes', nodes)

        g = tf.get_default_graph()
        exports = {}  # name -> tensor | function
        for name in tensors:
            imported_name = "%s/%s" % (import_scope, name)
            eprint("tensor", name, imported_name)
            exports[name] = g.get_operation_by_name(imported_name)

        for name, (scope, full_input_names, output_names) in functions.items():
            inputs = [
                g.get_tensor_by_name("%s/%s:0" %
                                     (import_scope, full_input_name))
                for full_input_name in full_input_names
            ]
            source_scope = "%s/%s/%s" % (import_path, internal_scope, scope)
            eprint("sgv_scope", source_scope, g)
            # The view is selected by regex; escape the scope so that it is
            # matched literally.
            source_pattern = "%s/(?:_|outputs)/.*" % re.escape(source_scope)
            sgv = make_view(source_pattern, graph=g)
            eprint("sgv.inputs", list(sgv.inputs))
            exports[name] = SubGraphViewFunction(name, sgv, source_scope,
                                                 inputs, output_names)

        self._exports = exports