Example #1
    def _share_initializer(self, graph_item, var_op_name, master_replica=0):
        """Share the initializers of all replica variables to use initializer on replica=master_replica."""
        # find the initial value of the var on master_replica
        master_var_op = graph_item.graph.get_operation_by_name(
            ops.prepend_name_scope(var_op_name,
                                   replica_prefix(master_replica)))
        master_var = graph_item.trainable_var_op_to_var[master_var_op]
        master_init_tensor = graph_item.graph.get_tensor_by_name(
            master_var.initial_value.name)
        master_init_op = master_init_tensor.op
        # set the device of the init ops to reside on the chief device
        master_init_device = device_spec.DeviceSpecV2.from_string(master_init_op.device) \
            .replace(task=0)
        master_init_op._set_device_from_string(master_init_device.to_string())

        for i in range(0, self.num_replicas):
            if i == master_replica:
                continue
            var_op = graph_item.graph.get_operation_by_name(
                ops.prepend_name_scope(var_op_name, replica_prefix(i)))
            var = graph_item.trainable_var_op_to_var[var_op]
            init_op = graph_item.graph.get_tensor_by_name(
                var.initial_value.name).op
            init_assign_op = get_consumers(init_op)[0]
            init_assign_op._update_input(1, master_init_tensor)
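
A note on the name-mapping pattern used throughout these examples: ops.prepend_name_scope moves an op or tensor name under a per-replica scope, and the prefixed op can then be fetched with graph.get_operation_by_name, exactly as _share_initializer does for the master replica above. A minimal, self-contained sketch, where the prefix helper is a hypothetical stand-in for replica_prefix(i):

from tensorflow.python.framework import ops

def sketch_replica_prefix(i):
    return 'Replica-%d' % i  # hypothetical stand-in; the real prefix string may differ

var_op_name = 'dense/kernel'
per_replica_names = [ops.prepend_name_scope(var_op_name, sketch_replica_prefix(i))
                     for i in range(2)]
assert per_replica_names == ['Replica-0/dense/kernel', 'Replica-1/dense/kernel']
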
Example #2
 def _handle_read(new_graph_item, var_op, partitioned_var):
     partitioned_var_tensor = partitioned_var.as_tensor()
     for op in get_consumers(var_op):
         op = new_graph_item.graph.get_operation_by_name(
             ops.prepend_name_scope(op.name, AUTODIST_TO_DELETE_SCOPE)
         )
         if op.type == "ResourceGather":
             # Only Resource Variable needs to be taken care of
             #   because ResourceGather consumes resource tensor rather than the tensor of read_var_op
             # Question: Is there any case where the op.type == "ResourceGather"
             #  but we can't use embedding_lookup_v2 to reconstruct the op consuming a partitioned resource
             # The second input to a ResourceGather op is always the indices per the opdef
             emb_lookup = embedding_ops.embedding_lookup_v2(partitioned_var, ids=op.inputs[1])
             update_consumers(get_consumers(op), op.outputs[0], emb_lookup)
         if is_read_var_op(op, version=1):
             # Without our modification, Reference Vars in TF have a read op associated with them.
             # TF can sometimes look for this and expect it to exist (e.g. in graph.as_graph_element)
             # so we add one back to avoid errors.
             # read_out is already the output tensor of the generated identity op
             read_out = array_ops.identity(partitioned_var_tensor,
                                           name=ops.prepend_name_scope("read", var_op.name))
             update_consumers(get_consumers(op), op.outputs[0], read_out)
         elif is_read_var_op(op, version=2):
             read_out = array_ops.identity(partitioned_var_tensor,
                                           name=ops.prepend_name_scope("Read/ReadVariableOp", var_op.name))
             update_consumers(get_consumers(op), op.outputs[0], read_out)
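
The helper update_consumers used above is an AutoDist utility not shown in this listing. As a rough mental model only (an assumption, not the project's implementation), it repoints every consumer input that reads the old tensor to the new tensor, via the same private _update_input API used in Example #1:

def sketch_update_consumers(consumer_ops, old_tensor, new_tensor):
    """Rough sketch: rewire each consumer input from old_tensor to new_tensor."""
    for consumer in consumer_ops:
        for idx, inp in enumerate(consumer.inputs):
            if inp is old_tensor:
                consumer._update_input(idx, new_tensor)  # pylint: disable=protected-access
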
Example #3
    def _get_aggregated_sparse_grad(self, graph_item, var_op, grad,
                                    reduce_to_device, BFTaggregator):
        indices_op_name = strip_replica_prefix(get_op_name(grad.indices.name))
        values_op_name = strip_replica_prefix(get_op_name(grad.values.name))
        dense_shape_op_name = strip_replica_prefix(
            get_op_name(grad.dense_shape.name))

        indexed_slices_grads = []
        for i in range(self.num_replicas):
            indices_op = graph_item.graph.get_operation_by_name(
                ops.prepend_name_scope(indices_op_name, replica_prefix(i)))
            values_op = graph_item.graph.get_operation_by_name(
                ops.prepend_name_scope(values_op_name, replica_prefix(i)))
            dense_shape_op = graph_item.graph.get_operation_by_name(
                ops.prepend_name_scope(dense_shape_op_name, replica_prefix(i)))
            indexed_slices_grads.append(
                ops.IndexedSlices(
                    values_op.outputs[utils.get_index_from_tensor_name(
                        grad.values.name)],
                    indices_op.outputs[utils.get_index_from_tensor_name(
                        grad.indices.name)],
                    dense_shape_op.outputs[utils.get_index_from_tensor_name(
                        grad.dense_shape.name)]))

        return self._aggregate_sparse_gradients(var_op, reduce_to_device,
                                                indexed_slices_grads,
                                                values_op_name)
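
For context, the name utilities assumed above split a tensor name of the form 'op_name:output_index'. A hypothetical sketch of their behavior (stand-ins, not AutoDist's implementations):

def sketch_get_op_name(tensor_name):
    """'grads/my_grad:1' -> 'grads/my_grad' (a leading '^' marks a control input)."""
    return tensor_name.lstrip('^').split(':')[0]

def sketch_get_index_from_tensor_name(tensor_name):
    """'grads/my_grad:1' -> 1; plain op names default to output 0."""
    return int(tensor_name.split(':')[1]) if ':' in tensor_name else 0

assert sketch_get_op_name('grads/my_grad:1') == 'grads/my_grad'
assert sketch_get_index_from_tensor_name('grads/my_grad:1') == 1
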
Example #4
    def _batch_prepend_name_scope(self, to_rename, new_name_scope):
        """
        Construct a new GraphItem with all ops in `to_rename` under `new_name_scope`.

        Args:
            to_rename (set): Collection of ops to rename
            new_name_scope (str): The new name scope to prepend to all ops

        Returns:
            GraphItem
        """
        og_graph_def = self.graph_item.graph.as_graph_def()
        new_graph_def = graph_pb2.GraphDef()
        new_graph_def.library.Clear()
        new_graph_def.library.CopyFrom(og_graph_def.library)
        control_flow_contexts = {}

        for node in og_graph_def.node:
            op = self.graph_item.graph.get_operation_by_name(node.name)

            # Save the control flow context so it can be added back later,
            # since it is not automatically set based on the attrs in the graph def
            ctx = op._get_control_flow_context()
            if ctx:
                control_flow_contexts[op.name] = ctx

            if op in to_rename:
                node.name = ops.prepend_name_scope(node.name, new_name_scope)

            # Rename inputs
            for idx, input_name in enumerate(node.input):
                input_op = self.graph_item.graph.get_operation_by_name(
                    get_op_name(input_name))
                if input_op in to_rename:
                    node.input[idx] = ops.prepend_name_scope(
                        input_name, new_name_scope)

            # Fix colocation
            for idx, s in enumerate(node.attr['_class'].list.s):
                name = s[len(COLOCATION_PREFIX):].decode('utf-8')
                if self.graph_item.graph.get_operation_by_name(
                        name) in to_rename:
                    node.attr['_class'].list.s[idx] = (
                        COLOCATION_PREFIX +
                        as_bytes(ops.prepend_name_scope(name, new_name_scope)))

            new_graph_def.node.append(node)

        # Re-add control flow contexts
        new_graph_item = GraphItem(graph_def=new_graph_def)
        for op in new_graph_item.graph.get_operations():
            if op.name in control_flow_contexts:
                op._set_control_flow_context(control_flow_contexts[op.name])

        return new_graph_item
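
Background for the colocation fix above: TensorFlow stores colocation constraints in each NodeDef's '_class' attr as byte strings of the form b'loc:@<op_name>', which is what COLOCATION_PREFIX refers to (assumed here to be b'loc:@'). A minimal sketch of the rename step in isolation:

COLOCATION_PREFIX_SKETCH = b'loc:@'  # assumed value of COLOCATION_PREFIX

def sketch_rename_colocation(entry, new_name_scope):
    """Rewrite b'loc:@v0/read' -> b'loc:@<new_name_scope>/v0/read'."""
    name = entry[len(COLOCATION_PREFIX_SKETCH):].decode('utf-8')
    return COLOCATION_PREFIX_SKETCH + ('%s/%s' % (new_name_scope, name)).encode('utf-8')

assert sketch_rename_colocation(b'loc:@v0/read', 'Replica-1') == b'loc:@Replica-1/v0/read'
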
Example #5
def test_parse_name_scope():
    with ops.Graph().as_default():
        name_scope = 'name_scope/child_name_scope'
        a = constant_op.constant(5)
        new_name = ops.prepend_name_scope(a.name, name_scope)
        assert new_name == 'name_scope/child_name_scope/Const:0'
        assert name_scope == utils.parse_name_scope(new_name)
        assert '' == utils.parse_name_scope(a.name)

        with ops.control_dependencies([no_op(name='my_op')]):
            b = constant_op.constant(6)
        name_scope = 'name_scope'
        new_name = ops.prepend_name_scope(b.op.node_def.input[0], name_scope)
        assert new_name == '^name_scope/my_op'
        assert name_scope == utils.parse_name_scope(new_name)
Example #6
    def _get_aggregated_dense_grad(self, graph_item, grad_name,
                                   reduce_to_device, BFTaggregator):
        grad_op_name = strip_replica_prefix(get_op_name(grad_name))
        output_idx = get_index_from_tensor_name(grad_name)
        grad_ops = [
            graph_item.graph.get_operation_by_name(
                ops.prepend_name_scope(grad_op_name, replica_prefix(i)))
            for i in range(self.num_replicas)
        ]

        # Aggregate gradients on `reduce_to_device` (usually CPU)
        with ops.device(reduce_to_device):
            # The original (non-BFT) aggregation summed the replica gradients and
            # divided by the number of replicas:
            #   grad_sum_op_name = ops.prepend_name_scope(grad_op_name, u"%sAdd" % AUTODIST_PREFIX)
            #   grad_sum = math_ops.add_n([grad_op.outputs[output_idx] for grad_op in grad_ops], name=grad_sum_op_name)
            #   grad_avg_op_name = ops.prepend_name_scope(grad_op_name, u"%sDiv" % AUTODIST_PREFIX)
            #   grad_avg = math_ops.realdiv(grad_sum, self.num_replicas, name=grad_avg_op_name)

            # Aggregate the replica gradients with the Byzantine-fault-tolerant aggregator instead
            gradients = [grad_op.outputs[output_idx] for grad_op in grad_ops]
            grad_avg = BFTaggregator.aggregate(gradients)

        return grad_avg
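
BFTaggregator.aggregate() is project-specific and its aggregation rule is not shown in this listing. As one illustrative possibility only (an assumption, not the actual aggregator), a coordinate-wise median is a common Byzantine-fault-tolerant aggregation rule:

import tensorflow as tf

class MedianAggregatorSketch:
    """Illustrative only: coordinate-wise (upper) median across replicas."""

    def aggregate(self, gradients):
        stacked = tf.stack(gradients, axis=0)    # shape: [num_replicas, ...]
        sorted_grads = tf.sort(stacked, axis=0)  # sort along the replica axis
        return sorted_grads[len(gradients) // 2]
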
Example #7
    def between_graph_apply(self, graph_item, var_name):
        """
        Apply between-graph synchronization to the target ops in the graph.

        Args:
            graph_item: The current graph.
            var_name: the variable to be synchronized.

        Returns:
            graph_item.GraphItem: updated graph item.
        """
        if not self._sync:
            return graph_item
        item = graph_item
        # here the variable on replica:0 has been shared, so the original var_name won't work
        var_op_name = ops.prepend_name_scope(get_op_name(var_name),
                                             replica_prefix(0))
        gradient, target, update_op = item.var_op_name_to_grad_info[
            var_op_name]
        with item.graph.as_default():
            proxy = self._create_proxy(
                item, gradient, target) if self._local_replication else None
            if proxy:
                proxy.update_colocation_group(item.get_colocation_op)
            with item.graph.name_scope(self._BETWEEN_GRAPH_APPLY_SCOPE):
                self._var_op_to_agg_grad, self._var_op_to_accum_apply_op = \
                    self._get_accumulation_ops(item, gradient, target,
                                               1 if self._staleness > 0 else self.num_workers)
                self.add_sync_op(item, update_op, proxy)
            item.graph._names_in_use.pop(self._BETWEEN_GRAPH_APPLY_SCOPE)
        return item
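
_get_accumulation_ops is not shown in this listing; the gist (a sketch under assumptions, not the actual helper) is TensorFlow's accumulate-then-take pattern, in which each worker applies its gradient to a shared accumulator and the aggregated gradient is taken once enough gradients have arrived:

from tensorflow.python.ops import data_flow_ops

def sketch_accumulate(grad, var, num_required, shared_name='sketch_accum'):
    accum = data_flow_ops.ConditionalAccumulator(
        dtype=grad.dtype, shape=var.shape, shared_name=shared_name)
    accum_apply_op = accum.apply_grad(grad, local_step=0)  # each worker applies its gradient
    agg_grad = accum.take_grad(num_required)               # yields the aggregate once num_required grads arrive
    return accum_apply_op, agg_grad
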
Example #8
    def in_graph_apply(self, graph_item, var_name):
        """
        Perform in-graph synchronization based on AllReduce and TensorFlow Collective Ops.

        Note that collective ops currently only support dense tensors.

        Args:
            graph_item (graph_item.GraphItem): the graph_item to be distributed
            var_name (str): the corresponding variable name

        Returns:
            graph_item.GraphItem: The new graph
        """
        # Skip the allreduce synchronizer when the total number of replicas <= 1
        if self.num_replicas * self.num_workers <= 1:
            return graph_item

        item = graph_item
        var_op_name = get_op_name(var_name)

        # Look up the gradient of the variable on the master replica (dense or sparse)
        master_op_name = ops.prepend_name_scope(var_op_name, replica_prefix(0))
        graph_item.updated = True
        grad, _, _ = graph_item.var_op_name_to_grad_info_v2[master_op_name]
        graph_item.var_queried.append(master_op_name)
        with item.graph.as_default():
            self._share_initializer(item, var_op_name, master_replica=0)
            if isinstance(grad, ops.IndexedSlices):
                self._collect_sparse_gradients(item, var_op_name)
            else:
                self._collect_dense_gradients(item, var_op_name)
        return item
Example #9
    def _collect_dense_gradients(self, graph_item, var_op_name):
        """Append collective ops after the gradient is calculated."""
        if self.num_replicas * self.num_workers <= 1:
            raise ValueError(
                'CollectiveOps requires collective group size > 1')

        compressors = defaultdict(
            lambda: Compressor.create(self._compressor_type, var_op_name))

        conf = CollectiveOpsConfig()
        conf.group_size = len(self.all_canonical_replica_devices)
        conf.group_key = get_collective_keys().get_group_key(
            self.all_canonical_replica_devices)
        conf.instance_key = get_collective_keys().get_instance_key(var_op_name)
        conf.merge_op = 'Add'
        conf.final_op = 'Div'
        if self._spec:
            setattr(conf, 'communication_hint', self._spec)

        for i in range(0, self.num_replicas):
            op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
            graph_item.updated = True
            grad, _, _ = graph_item.var_op_name_to_grad_info_v2[op_name]
            # TODO (Tairui): (3) Merge of reduction for performance
            grad_consumers = get_consumers(
                grad.op)  # this line must happen before the reduction

            # "\/" is added for name scope reuse
            with ops.name_scope(
                    replica_prefix(i) +
                    "/collective-group-{}/".format(self._group)):
                with ops.colocate_with(grad.op):
                    reduced_grad = compressors[i].reduce(grad, conf)
            update_consumers(grad_consumers, grad, reduced_grad)
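
Compressor.reduce(grad, conf) is AutoDist-specific; presumably it ultimately issues a TensorFlow collective all-reduce parameterized by the CollectiveOpsConfig built above (a sketch under that assumption):

from tensorflow.python.ops import collective_ops

def sketch_all_reduce(grad, conf):
    """Average grad across the collective group described by conf."""
    return collective_ops.all_reduce(
        grad, conf.group_size, conf.group_key, conf.instance_key,
        merge_op=conf.merge_op,  # 'Add': sum across replicas
        final_op=conf.final_op)  # 'Div': divide by the group size
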
Example #10
    def in_graph_apply(self, graph_item, var_name):
        """
        Apply in-graph ps synchronization.

        Args:
            graph_item: the old graph item
            var_name: the variable name w/o replica prefix

        Returns:
            graph_item.GraphItem

        """
        item = graph_item
        var_op_name = get_op_name(var_name)
        master_replica_index = 0

        with item.graph.as_default():
            self._prune_control_dependencies(
                item, var_op_name, master_replica=master_replica_index)
            self._share_variable(item,
                                 var_op_name,
                                 master_replica=master_replica_index)
            master_var_name = ops.prepend_name_scope(
                var_name, replica_prefix(master_replica_index))
            master_var_op_name = get_op_name(master_var_name)
            item.updated = True
            grad, target, update_op = item.var_op_name_to_grad_info_v2[
                master_var_op_name]
            item.var_queried.append(master_var_op_name)
            agg_grad = self._aggregate_gradients(item,
                                                 old_update_op=update_op,
                                                 old_grad=grad,
                                                 old_target=target)

        # update grad_target_pair and variable info
        for i in range(self.num_replicas):
            var_name_to_remove = ops.prepend_name_scope(
                var_name, replica_prefix(i))
            item.pop_gradient_info(var_name=var_name_to_remove)
            if i != master_replica_index:
                item.info.pop_variable(var_name=var_name_to_remove)
        item.extend_gradient_info(
            grads=[agg_grad],
            targets=[item.graph.get_tensor_by_name(master_var_name)])
        # TODO(Hao): Prune the graph to remove unnecessary nodes
        return item
Example #11
    def _aggregate_sparse_gradients(self, var_op, reduce_to_device,
                                    indexed_slices_grads, values_op_name):
        with ops.device(reduce_to_device):
            grad_accum_op_name = ops.prepend_name_scope(
                values_op_name, u"%sAccum" % AUTODIST_PREFIX)
            grad_accum = data_flow_ops.SparseConditionalAccumulator(
                dtype=indexed_slices_grads[0].values.dtype,
                shape=var_op.outputs[0].shape,
                shared_name=grad_accum_op_name,
                name=grad_accum_op_name)
            accum_apply_ops = [
                grad_accum.apply_indexed_slices_grad(
                    indexed_slices_grads[i],
                    MAX_INT64,
                    name=ops.prepend_name_scope(
                        values_op_name, u"%s-Accum-Apply" % replica_prefix(i)))
                for i in range(self.num_replicas)
            ]
            take_grad_op_name = ops.prepend_name_scope(
                values_op_name, u"%sTake-Grad" % AUTODIST_PREFIX)
            with ops.control_dependencies(accum_apply_ops):
                take_grad = grad_accum.take_indexed_slices_grad(
                    self.num_replicas, name=take_grad_op_name)

            new_indices = take_grad.indices
            new_values = take_grad.values
            new_dense_shape = take_grad.dense_shape
            if indexed_slices_grads[0].indices.dtype != new_indices.dtype:
                new_indices = math_ops.cast(
                    new_indices,
                    indexed_slices_grads[0].indices.dtype,
                    name=ops.prepend_name_scope(
                        values_op_name,
                        u"%sTake-Grad-Cast-Indices" % AUTODIST_PREFIX))
            if indexed_slices_grads[
                    0].dense_shape.dtype != new_dense_shape.dtype:
                new_dense_shape = math_ops.cast(
                    new_dense_shape,
                    indexed_slices_grads[0].dense_shape.dtype,
                    name=ops.prepend_name_scope(
                        values_op_name,
                        u"%sTake-Grad-Cast-Shape" % AUTODIST_PREFIX))
        return ops.IndexedSlices(new_values, new_indices, new_dense_shape)
Example #12
    def _update_gradient_consumers(new_graph_item, consumer_ops,
                                   control_consumer_ops, old_tensor_name,
                                   new_tensor):
        """Make gradient's consumers consume the aggregated gradient instead of the original one of replica_0."""
        # Get the original tensor (the one from replica 0) to replace
        old_op_name = strip_replica_prefix(get_op_name(old_tensor_name))
        replica_0_op_name = ops.prepend_name_scope(old_op_name,
                                                   replica_prefix(0))
        replica_0_op = new_graph_item.graph.get_operation_by_name(
            replica_0_op_name)
        output_idx = get_index_from_tensor_name(old_tensor_name)
        replica_0_tensor = replica_0_op.outputs[output_idx]

        update_consumers(consumer_ops, replica_0_tensor, new_tensor)
        update_control_consumers(control_consumer_ops, replica_0_tensor.op,
                                 new_tensor.op)
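
update_control_consumers is another AutoDist helper; a rough sketch of its effect (an assumption about its behavior, using private TF graph-editing APIs):

def sketch_update_control_consumers(control_consumer_ops, old_op, new_op):
    """Repoint control dependencies from old_op to new_op."""
    for consumer in control_consumer_ops:
        keep = [op for op in consumer.control_inputs if op is not old_op]
        consumer._remove_all_control_inputs()          # pylint: disable=protected-access
        consumer._add_control_inputs(keep + [new_op])  # pylint: disable=protected-access
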
Example #13
 def _collect_sparse_gradients(self, graph_item, var_op_name):
     """Append collective ops after the gradient is calculated."""
     if self.num_workers > 1 and not ENV.AUTODIST_INTERNAL_TF.value:
         raise NotImplementedError(
             'Currently the collective NCCL AllGather is not supported in the TensorFlow release. '
             'Please choose another strategy.')
     conf = {}
     if self._spec:
         conf = {'communication_hint': self._spec}
     if self._compressor_type:
         logging.warning(
             'AllGather currently does not support an AutoDist compressor, so compression is skipped.'
         )
     if self.num_replicas * self.num_workers <= 1:
         raise ValueError(
             'CollectiveOps requires collective group size > 1')
     for i in range(0, self.num_replicas):
         op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
         graph_item.updated = True
         grad, _, _ = graph_item.var_op_name_to_grad_info_v2[op_name]
         # TODO (Tairui): (3) Merge of reduction for performance
         indices_c_ops = grad.indices.consumers()
         indices_cc_ops = get_control_consumers(grad.indices.op)
         values_c_ops = grad.values.consumers()
         values_cc_ops = get_control_consumers(grad.values.op)
         with ops.name_scope(replica_prefix(i)):
             with ops.colocate_with(grad.indices.op):
                 new_indices = collective_ops.all_gather(
                     grad.indices, self.num_replicas * self.num_workers,
                     get_collective_keys().get_group_key(
                         self.all_canonical_replica_devices),
                     get_collective_keys().get_instance_key(var_op_name +
                                                            '-indices'),
                     **conf)
             with ops.colocate_with(grad.values.op):
                 new_values = collective_ops.all_gather(
                     grad.values, self.num_replicas * self.num_workers,
                     get_collective_keys().get_group_key(
                         self.all_canonical_replica_devices),
                     get_collective_keys().get_instance_key(var_op_name +
                                                            '-values'),
                     **conf)
         update_consumers(indices_c_ops, grad.indices, new_indices)
         update_control_consumers(indices_cc_ops, grad.indices.op,
                                  new_indices.op)
         update_consumers(values_c_ops, grad.values, new_values)
         update_control_consumers(values_cc_ops, grad.values.op, new_values.op)
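
For intuition about the gather above (a self-contained illustration, not AutoDist code): collective all_gather concatenates each replica's tensor along axis 0, so gathering the indices and values of an IndexedSlices gradient amounts to stacking the sparse updates from all replicas:

import tensorflow as tf

replica_indices = [tf.constant([0, 2]), tf.constant([1, 2])]
replica_values = [tf.constant([[1.0], [2.0]]), tf.constant([[3.0], [4.0]])]
gathered_indices = tf.concat(replica_indices, axis=0)  # [0, 2, 1, 2]
gathered_values = tf.concat(replica_values, axis=0)    # [[1.], [2.], [3.], [4.]]
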
Example #14
    def _build_proxy_on(self, destination_device):
        """
        Build a proxy of the original variable on `destination_device`.

        Args:
            destination_device (DeviceSpecV2): the destination device where the proxy is on.
        """
        is_gpu = (destination_device.device_type.upper() == 'GPU'
                  if destination_device.device_type else False)
        prefix = (replica_prefix(destination_device.device_index)
                  if is_gpu else replica_prefix('CPU'))
        with ops.device(destination_device):
            proxy_var = variable_scope.get_variable(
                ops.prepend_name_scope(self._this_op.name, prefix),
                dtype=self._dtype,
                initializer=self._initial_value,
                trainable=False)
        self._graph_item.info.update_variables(
            [proxy_var], replace=False)  # Should we update graph_item.info?
        self._proxy_vars.append(proxy_var)
        self._proxy_var_init_ops.append(
            proxy_var.assign(get_read_var_tensor(self._this_op)))
        self._mirror_all_read_var_ops()
        self._update_all_consumers()
Example #15
    def _prune_control_dependencies(self,
                                    graph_item,
                                    var_op_name,
                                    master_replica=0):
        """
        Prune the control dependencies between the train_op on non-master replica and update op.

        Since the replicator replicates the entire graph, the update ops on non-master replicas
        are replicated as well. If the train_op on a non-master replica is fetched (which is the case
        in our current feed-fetch remap implementation), it triggers those update ops and results
        in unnecessary updates to the trainable variables.
        To avoid this, this function prunes the control dependencies between the train_op and any
        variable that is synchronized by a PS synchronizer.
        """
        for i in range(self.num_replicas):
            if i == master_replica:
                continue
            this_var_op_name = ops.prepend_name_scope(var_op_name,
                                                      replica_prefix(i))
            _, _, update_op = graph_item.var_op_name_to_grad_info[
                this_var_op_name]
            source_op = self._get_optimizer_source_op(update_op)
            remove_from_control_consumers(get_control_consumers(source_op),
                                          source_op)
Example #16
    def replicate(self, graph_item):
        """
        Replicate the entire graph as many times as num_replica.

        Args:
            graph_item: the original graph item

        Returns: The new graph item
        """
        item = GraphItem(graph=ops.Graph())
        fwd_ctx, bwd_ctx = self._collect_while_context(graph_item.graph)
        with item.graph.as_default():
            gdef = graph_item.graph.as_graph_def()
            for i in range(self._num_local_replicas):
                # Replicate ops
                with ops.device(self._replica_device_placer(replica_id=i)):
                    import_graph_def(gdef, name=replica_prefix(i))

                # Replicate while_loop context (control_flow) if needed.
                # The order matters -- We must replicate bwd context first, then forward context.
                # TODO(Zeya): To handle cases when there are nested while loops, in which we must replicate
                #  parent context first and then child context. See:
                #  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_ops.py#L938
                if bwd_ctx:
                    for ctx in bwd_ctx:
                        _ = WhileContext(context_def=ctx.to_proto(), grad_state=ctx._grad_state,
                                         import_scope=replica_prefix(i))
                if fwd_ctx:
                    for ctx in fwd_ctx:
                        _ = WhileContext(context_def=ctx.to_proto(), grad_state=ctx._grad_state,
                                         import_scope=replica_prefix(i))

            # update saver
            master_replica = 0
            if graph_item.info.savers:
                item.info.update_savers(
                    [Saver.from_proto(proto, import_scope=replica_prefix(master_replica)).to_proto()
                        for proto in graph_item.info.savers],
                    replace=False
                )

            # update gradient info
            for i in range(self._num_local_replicas):
                for g_name, t_name in graph_item.grad_target_name_pairs.items():
                    if isinstance(g_name, tuple):
                        new_g_name = (
                            ops.prepend_name_scope(g_name[0], replica_prefix(i)),
                            ops.prepend_name_scope(g_name[1], replica_prefix(i)),
                            ops.prepend_name_scope(g_name[2], replica_prefix(i)))
                    else:
                        new_g_name = ops.prepend_name_scope(g_name, replica_prefix(i))
                    new_t_name = ops.prepend_name_scope(t_name, replica_prefix(i))
                    item.extend_gradient_info_by_names(
                        grads=[new_g_name],
                        targets=[new_t_name]
                    )
                item.info.update_variables(
                    [_from_proto_fn(proto, import_scope=replica_prefix(i)).to_proto()
                        for proto in graph_item.info.variables],
                    replace=False
                )
                item.info.update_table_initializers(
                    [ops.prepend_name_scope(tb_init, replica_prefix(i))
                        for tb_init in graph_item.info.table_initializers],
                    replace=False
                )
        return item
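
The core trick in replicate() is that importing the same GraphDef under different name scopes yields one prefixed copy of every op per replica. A self-contained illustration, where 'Replica-%d' is a hypothetical stand-in for replica_prefix(i):

from tensorflow.python.framework import constant_op, importer, ops

source = ops.Graph()
with source.as_default():
    constant_op.constant(1.0, name='x')
gdef = source.as_graph_def()

replicated = ops.Graph()
with replicated.as_default():
    for i in range(2):
        importer.import_graph_def(gdef, name='Replica-%d' % i)

print([op.name for op in replicated.get_operations()])  # ['Replica-0/x', 'Replica-1/x']
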
Example #17
def test_strip_replica_prefix():
    for name in ['my_op', '^my_op', 'my_tensor:0']:
        new_name = ops.prepend_name_scope(name, replica_prefix(12))
        assert strip_replica_prefix(new_name) == name
Example #18
    def _share_variable(self, graph_item, var_op_name, master_replica=0):
        """
        Share the variable on the replica = `master_replica` (default to 0).

        Update inputs of consumers of the variable on replica > 0 to variable on replica=`master_replica`.

        Args:
            graph_item: the old graph item
            var_op_name: the name of the variable op of the variable to be shared
            master_replica: the index of master replica (default to 0)
        """
        for i in range(0, self.num_replicas):
            if i == master_replica:
                continue
            this_var_op_name = ops.prepend_name_scope(var_op_name,
                                                      replica_prefix(i))
            this_var_op = graph_item.graph.get_operation_by_name(
                this_var_op_name)

            # Get all read variable ops to this replica variable
            read_var_ops = get_read_var_ops(this_var_op)

            # Get all consumers of its VarHandleOp,
            # excluding ReadVariableOps and those not in its variable scope
            handle_consumers = set(get_consumers(this_var_op))
            handle_consumers.difference_update(set(read_var_ops))
            handle_consumers.difference_update({
                con
                for con in handle_consumers
                if con.name.startswith(this_var_op_name + '/')
            })
            # We exclude the `update_op` when updating the consumers of the shared variables,
            # because i) sharing a variable implies sharing its stateful ops correspondingly
            # (so it would be fine to remove the stateful ops on non-master replicas, but we only disconnect them),
            # and ii) a variable cannot correspond to more than one update op for now.
            handle_consumers.difference_update(set(graph_item.all_update_ops))

            # update the consumers of all read variable ops to use the read variable ops of replica=master_replica
            for read_var_op in read_var_ops:
                new_read_var_op_name = ops.prepend_name_scope(
                    ops.strip_name_scope(read_var_op.name, replica_prefix(i)),
                    replica_prefix(master_replica))
                new_read_var_op = graph_item.graph.get_operation_by_name(
                    new_read_var_op_name)
                consumers = get_consumers(read_var_op)
                update_consumers(consumers, read_var_op.outputs[0],
                                 new_read_var_op.outputs[0])
                update_colocation_group(consumers, read_var_op,
                                        new_read_var_op)

            # update the consumers of the VarHandleOp to use the handle on replica=master_replica
            new_handle_op_name = ops.prepend_name_scope(
                ops.strip_name_scope(this_var_op_name, replica_prefix(i)),
                replica_prefix(master_replica))
            new_handle_op = graph_item.graph.get_operation_by_name(
                new_handle_op_name)
            handle_consumers = list(handle_consumers)
            update_consumers(handle_consumers, this_var_op.outputs[0],
                             new_handle_op.outputs[0])
            update_colocation_group(handle_consumers, this_var_op,
                                    new_handle_op)
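
get_read_var_ops is an AutoDist utility; a plausible sketch for resource variables (an assumption, not the project's implementation) is to collect the ReadVariableOp consumers of the variable's VarHandleOp output:

def sketch_get_read_var_ops(var_op):
    """Return the ReadVariableOp consumers of a resource variable's handle."""
    return [op for op in var_op.outputs[0].consumers() if op.type == 'ReadVariableOp']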