def _batch_prepend_name_scope(self, to_rename, new_name_scope):
        """
        Build a GraphItem in which every op of `to_rename` sits under `new_name_scope`.

        The graph of `self.graph_item` is exported to a GraphDef. For each node,
        the node name, its inputs and its colocation (`_class`) entries are
        rewritten whenever they refer to an op contained in `to_rename`. The
        function library is copied over unchanged.

        Args:
            to_rename (set): Ops that should receive the new name scope
            new_name_scope (str): Name scope to put in front of the op names

        Returns:
            GraphItem: a new item created from the rewritten GraphDef
        """
        graph = self.graph_item.graph
        source_def = graph.as_graph_def()
        renamed_def = graph_pb2.GraphDef()
        renamed_def.library.Clear()
        renamed_def.library.CopyFrom(source_def.library)
        saved_contexts = {}

        def _is_renamed(op_name):
            # True when the op registered under `op_name` is scheduled for renaming.
            return graph.get_operation_by_name(op_name) in to_rename

        for node_def in source_def.node:
            original_op = graph.get_operation_by_name(node_def.name)

            # Control flow contexts are not rebuilt from the GraphDef attrs,
            # so remember them (keyed by the original op name) for later.
            context = original_op._get_control_flow_context()
            if context:
                saved_contexts[original_op.name] = context

            if original_op in to_rename:
                node_def.name = ops.prepend_name_scope(
                    node_def.name, new_name_scope)

            # Point inputs at the new names of renamed producers
            for pos, tensor_name in enumerate(node_def.input):
                if _is_renamed(get_op_name(tensor_name)):
                    node_def.input[pos] = ops.prepend_name_scope(
                        tensor_name, new_name_scope)

            # Point colocation constraints at the new names as well
            colocation_entries = node_def.attr['_class'].list.s
            for pos, entry in enumerate(colocation_entries):
                target = entry[len(COLOCATION_PREFIX):].decode('utf-8')
                if not _is_renamed(target):
                    continue
                scoped_target = ops.prepend_name_scope(target, new_name_scope)
                colocation_entries[pos] = (
                    COLOCATION_PREFIX + as_bytes(scoped_target))

            renamed_def.node.append(node_def)

        # Restore the remembered control flow contexts on the new graph's ops
        renamed_item = GraphItem(graph_def=renamed_def)
        for new_op in renamed_item.graph.get_operations():
            context = saved_contexts.get(new_op.name)
            if context is not None:
                new_op._set_control_flow_context(context)

        return renamed_item
Example #2
0
    def _delete_marked_ops(self, graph_item, name_scope):
        """
        Constructs a new GraphItem with all ops under `name_scope` removed.

        Nodes whose name scope starts with `name_scope` are skipped. For the
        remaining nodes, inputs that refer to a removed op are blanked and
        colocation (`_class`) entries that refer to a removed op are dropped.
        Control flow contexts of the kept ops are saved by op name and set
        again on the ops of the new graph.

        Args:
            graph_item (GraphItem): The current GraphItem.
            name_scope (str): The name scope to remove.

        Returns:
            GraphItem
        """
        graph_def = graph_item.graph.as_graph_def()
        new_graph_def = graph_pb2.GraphDef()
        new_graph_def.library.Clear()
        new_graph_def.library.CopyFrom(graph_def.library)
        control_flow_contexts = {}

        def _is_marked(name):
            # True when `name` lives under the name scope being deleted.
            return parse_name_scope(name).startswith(name_scope)

        for node in graph_def.node:
            if _is_marked(node.name):
                continue

            # Save control flow context to add it back later
            # Since it is not automatically set based on the attr's in the graph def
            op = graph_item.graph.get_operation_by_name(node.name)
            ctx = op._get_control_flow_context()
            if ctx:
                control_flow_contexts[op.name] = ctx

            # Blank out inputs produced by deleted ops.
            for idx, input_name in enumerate(node.input):
                if _is_marked(input_name):
                    node.input[idx] = ""

            # Drop colocation constraints that point at deleted ops. The
            # surviving entries are computed first and the repeated field is
            # rewritten afterwards: removing entries while iterating over the
            # same field skips the entry that follows each removed one.
            colocations = node.attr['_class'].list.s
            kept = [
                s for s in colocations
                if not _is_marked(s[len(COLOCATION_PREFIX):].decode('utf-8'))
            ]
            if len(kept) != len(colocations):
                del colocations[:]
                colocations.extend(kept)

            # NOTE(review): presumably drops the inputs blanked above --
            # confirm against the helper's definition.
            self._prune_graphdef_node_inputs(node)

            new_graph_def.node.append(node)

        # Re-add control flow contexts
        new_graph_item = GraphItem(graph_def=new_graph_def)
        for op in new_graph_item.graph.get_operations():
            if op.name in control_flow_contexts:
                op._set_control_flow_context(control_flow_contexts[op.name])

        return new_graph_item
Example #3
0
 def _initialize_graph(self):
     """
     Capture the current default graph as the original graph item.

     The member `_original_graph_item` is deliberately not set earlier; it is
     filled in when this method is called (at scoping time). Eager execution
     must be off, which is asserted.
     """
     assert not tf_context.executing_eagerly()
     default_graph = ops.get_default_graph()
     self._original_graph_item = GraphItem(graph=default_graph)
Example #4
0
    def replicate(self, graph_item):
        """
        Copy the whole graph once per local replica into a fresh graph item.

        Each copy is imported under its replica prefix, inside the device
        scope returned by `self._replica_device_placer`. While-loop contexts,
        savers, gradient/target name pairs, variables and table initializers
        are carried over with the matching prefix. The number of copies is
        `self._num_local_replicas`.

        Args:
            graph_item: the original graph item

        Returns: The new graph item
        """
        replicated = GraphItem(graph=ops.Graph())
        fwd_ctx, bwd_ctx = self._collect_while_context(graph_item.graph)
        with replicated.graph.as_default():
            source_def = graph_item.graph.as_graph_def()
            for replica_id in range(self._num_local_replicas):
                scope = replica_prefix(replica_id)

                # Replicate ops
                placer = self._replica_device_placer(replica_id=replica_id)
                with ops.device(placer):
                    import_graph_def(source_def, name=scope)

                # Replicate while_loop context (control_flow) if needed.
                # The order matters -- backward contexts must be replicated before forward ones.
                # TODO(Zeya): To handle cases when there are nested while loops, in which we must replicate
                #  parent context first and then child context. See:
                #  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_ops.py#L938
                for context_group in (bwd_ctx, fwd_ctx):
                    if not context_group:
                        continue
                    for while_ctx in context_group:
                        _ = WhileContext(context_def=while_ctx.to_proto(),
                                         grad_state=while_ctx._grad_state,
                                         import_scope=scope)

            # update saver: savers are re-created under the master replica's scope only
            master_replica = 0
            if graph_item.info.savers:
                master_scope = replica_prefix(master_replica)
                saver_protos = [
                    Saver.from_proto(proto, import_scope=master_scope).to_proto()
                    for proto in graph_item.info.savers
                ]
                replicated.info.update_savers(saver_protos, replace=False)

            # update gradient info
            for replica_id in range(self._num_local_replicas):
                scope = replica_prefix(replica_id)
                for grad_name, target_name in graph_item.grad_target_name_pairs.items():
                    if isinstance(grad_name, tuple):
                        # Tuple-valued gradient names: the first three entries are each re-scoped.
                        scoped_grad = tuple(
                            ops.prepend_name_scope(grad_name[k], scope)
                            for k in range(3))
                    else:
                        scoped_grad = ops.prepend_name_scope(grad_name, scope)
                    scoped_target = ops.prepend_name_scope(target_name, scope)
                    replicated.extend_gradient_info_by_names(
                        grads=[scoped_grad],
                        targets=[scoped_target]
                    )

                variable_protos = [
                    _from_proto_fn(proto, import_scope=scope).to_proto()
                    for proto in graph_item.info.variables
                ]
                replicated.info.update_variables(variable_protos, replace=False)

                table_initializers = [
                    ops.prepend_name_scope(tb_init, scope)
                    for tb_init in graph_item.info.table_initializers
                ]
                replicated.info.update_table_initializers(
                    table_initializers, replace=False)
        return replicated