    def _get_aggregated_sparse_grad(self, graph_item, var_op, grad,
                                    reduce_to_device, BFTaggregator):
        indices_op_name = strip_replica_prefix(get_op_name(grad.indices.name))
        values_op_name = strip_replica_prefix(get_op_name(grad.values.name))
        dense_shape_op_name = strip_replica_prefix(
            get_op_name(grad.dense_shape.name))

        indexed_slices_grads = []
        for i in range(self.num_replicas):
            indices_op = graph_item.graph.get_operation_by_name(
                ops.prepend_name_scope(indices_op_name, replica_prefix(i)))
            values_op = graph_item.graph.get_operation_by_name(
                ops.prepend_name_scope(values_op_name, replica_prefix(i)))
            dense_shape_op = graph_item.graph.get_operation_by_name(
                ops.prepend_name_scope(dense_shape_op_name, replica_prefix(i)))
            indexed_slices_grads.append(
                ops.IndexedSlices(
                    values_op.outputs[utils.get_index_from_tensor_name(
                        grad.values.name)],
                    indices_op.outputs[utils.get_index_from_tensor_name(
                        grad.indices.name)],
                    dense_shape_op.outputs[utils.get_index_from_tensor_name(
                        grad.dense_shape.name)]))

        return self._aggregate_sparse_gradients(var_op, reduce_to_device,
                                                indexed_slices_grads,
                                                values_op_name)
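The helper `_aggregate_sparse_gradients` is not shown in this example. A minimal sketch of what such a step could do, assuming the per-replica IndexedSlices are simply concatenated on the reduction device and their values averaged (the function name and details below are illustrative, not the actual implementation):

import tensorflow as tf

def aggregate_indexed_slices(indexed_slices_grads, reduce_to_device, num_replicas):
    """Concatenate per-replica IndexedSlices and average their values (sketch)."""
    with tf.device(reduce_to_device):
        values = tf.concat([g.values for g in indexed_slices_grads], axis=0)
        indices = tf.concat([g.indices for g in indexed_slices_grads], axis=0)
        # Dividing the values by the replica count yields an average once the
        # slices are later combined into the dense variable update.
        return tf.IndexedSlices(values / num_replicas, indices,
                                dense_shape=indexed_slices_grads[0].dense_shape)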
Example #2
    def build(self, graph_item, resource_spec):
        """Generate the strategy."""
        expr = Strategy()

        # For each variable, generate variable synchronizer config
        expr.graph_config.replicas.extend(
            [k for k, v in resource_spec.gpu_devices])
        reduction_device_names = [k for k, _ in resource_spec.cpu_devices]
        self.loads = {ps: 0.0 for ps in reduction_device_names}

        # Generate node config
        node_config = []
        for idx, var in enumerate(graph_item.trainable_var_op_to_var.values()):
            var_op_name = get_op_name(var.name)
            grad, _, _ = graph_item.var_op_name_to_grad_info[var_op_name]
            if isinstance(grad, ops.Tensor):  # this is a dense variable
                group_id = idx // self.chunk_size
                config = self._gen_all_reduce_node_config(var.name,
                                                          group=group_id)
            else:  # sparse updates
                # For the Parallax strategy, all PS variables are sparse, so we don't use a proxy:
                # sparse variables tend to be large, keeping local copies would be costly,
                # and each device usually needs only a small slice of the overall variable.
                config = self._gen_ps_node_config(
                    var,
                    False,  # no proxy; see the comment above
                    self._sync,
                    self._staleness)
            node_config.append(config)
        expr.node_config.extend(node_config)

        return expr
Example #3
    def in_graph_apply(self, graph_item, var_name):
        """
        Perform in-graph synchronization based on AllReduce and TensorFlow Collective Ops.

        Note that collective ops currently only support dense tensors.

        Args:
            graph_item (graph_item.GraphItem): the graph_item to be distributed
            var_name (str): the corresponding variable name

        Returns:
            graph_item.GraphItem: The new graph
        """
        # Skip allreduce synchronizer when rank <= 1
        if self.num_replicas * self.num_workers <= 1:
            return graph_item

        item = graph_item
        var_op_name = get_op_name(var_name)

        # Look up the gradient info of the master replica's copy of the variable
        master_op_name = ops.prepend_name_scope(var_op_name, replica_prefix(0))
        graph_item.updated = True
        grad, _, _ = graph_item.var_op_name_to_grad_info_v2[master_op_name]
        graph_item.var_queried.append(master_op_name)
        with item.graph.as_default():
            self._share_initializer(item, var_op_name, master_replica=0)
            if isinstance(grad, ops.IndexedSlices):
                self._collect_sparse_gradients(item, var_op_name)
            else:
                self._collect_dense_gradients(item, var_op_name)
        return item
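`_collect_dense_gradients` is defined elsewhere; conceptually it swaps each replica's dense gradient for a collective all-reduce of that gradient. A rough sketch using TensorFlow's collective ops, assuming made-up group and instance keys (the exact call and options used by the synchronizer may differ):

from tensorflow.python.ops import collective_ops

def all_reduce_mean(grad, num_devices, group_key, instance_key):
    """All-reduce a dense gradient across devices and average it (sketch)."""
    # merge_op='Add' sums the tensors across the group; final_op='Div' divides
    # the result by group_size, giving the mean gradient on every device.
    return collective_ops.all_reduce(grad,
                                     group_size=num_devices,
                                     group_key=group_key,
                                     instance_key=instance_key,
                                     merge_op='Add',
                                     final_op='Div')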
Example #4
    def _get_vars_to_partition(self):
        """
        Analyzes the strategy and returns mappings for the vars to partition and the vars not to partition.

        Returns:
            vars_to_partition (Dict): Mapping of variable names to the tuple of partition_str and reduction devices.
            unpartitioned_vars (Dict): Mapping from variable name to gradient name of unpartitioned vars.
        """
        vars_to_partition = {}
        unpartitioned_vars = {}
        for node in self.node_config:
            partitioner = getattr(node, 'partitioner')
            if partitioner:
                reduction_destinations = []
                for part in node.part_config:
                    synchronizer = getattr(part, part.WhichOneof('synchronizer'))
                    if hasattr(synchronizer, 'reduction_destination'):
                        reduction_destinations.append(synchronizer.reduction_destination)
                    else:
                        reduction_destinations.append('')
                vars_to_partition[node.var_name] = (partitioner, reduction_destinations)
                logging.info("Partitioning variable {} with configuration {}".format(node.var_name, partitioner))
            else:
                grad, _, _ = self.graph_item.var_op_name_to_grad_info[get_op_name(node.var_name)]
                unpartitioned_vars[node.var_name] = grad
        return vars_to_partition, unpartitioned_vars
    def between_graph_apply(self, graph_item, var_name):
        """
        Apply between-graph synchronization to the target ops in the graph.

        Args:
            graph_item: The current graph.
            var_name: the variable to be synchronized.

        Returns:
            graph_item.GraphItem: updated graph item.
        """
        if not self._sync:
            return graph_item
        item = graph_item
        # here the variable on replica:0 has been shared, so the original var_name won't work
        var_op_name = ops.prepend_name_scope(get_op_name(var_name),
                                             replica_prefix(0))
        gradient, target, update_op = item.var_op_name_to_grad_info[
            var_op_name]
        with item.graph.as_default():
            proxy = self._create_proxy(
                item, gradient, target) if self._local_replication else None
            if proxy:
                proxy.update_colocation_group(item.get_colocation_op)
            with item.graph.name_scope(self._BETWEEN_GRAPH_APPLY_SCOPE):
                self._var_op_to_agg_grad, self._var_op_to_accum_apply_op = \
                    self._get_accumulation_ops(item, gradient, target,
                                               1 if self._staleness > 0 else self.num_workers)
                self.add_sync_op(item, update_op, proxy)
            item.graph._names_in_use.pop(self._BETWEEN_GRAPH_APPLY_SCOPE)
        return item
    def _get_aggregated_dense_grad(self, graph_item, grad_name,
                                   reduce_to_device, BFTaggregator):
        grad_op_name = strip_replica_prefix(get_op_name(grad_name))
        output_idx = get_index_from_tensor_name(grad_name)
        grad_ops = [
            graph_item.graph.get_operation_by_name(
                ops.prepend_name_scope(grad_op_name, replica_prefix(i)))
            for i in range(self.num_replicas)
        ]

        # Aggregate gradients on `reduce_to_device` (usually CPU)
        with ops.device(reduce_to_device):
            # The original sum-and-average aggregation (math_ops.add_n over the replica
            # outputs followed by math_ops.realdiv by num_replicas) is replaced here by
            # the Byzantine-fault-tolerant aggregator.
            gradients = [grad_op.outputs[output_idx] for grad_op in grad_ops]
            grad_avg = BFTaggregator.aggregate(gradients)

        return grad_avg
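`BFTaggregator` is supplied by the caller; the only interface relied on here is an `aggregate(gradients)` method mapping a list of per-replica gradient tensors to one tensor. A minimal sketch of such an aggregator, using a coordinate-wise median as one example of a Byzantine-fault-tolerant rule (illustrative only, not the aggregator actually used):

import tensorflow as tf

class MedianAggregator:
    """Coordinate-wise median over per-replica gradients (illustrative sketch)."""

    def aggregate(self, gradients):
        # Stack replica gradients along a new leading axis: [num_replicas, ...].
        stacked = tf.stack(gradients, axis=0)
        # The coordinate-wise median is robust to a minority of corrupted replicas.
        n = len(gradients)
        sorted_grads = tf.sort(stacked, axis=0)
        if n % 2 == 1:
            return sorted_grads[n // 2]
        return (sorted_grads[n // 2 - 1] + sorted_grads[n // 2]) / 2.0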
Example #7
    def in_graph_apply(self, graph_item, var_name):
        """
        Apply in-graph ps synchronization.

        Args:
            graph_item: the old graph item
            var_name: the variable name w/o replica prefix

        Returns:
            graph_item.GraphItem

        """
        item = graph_item
        var_op_name = get_op_name(var_name)
        master_replica_index = 0

        with item.graph.as_default():
            self._prune_control_dependencies(
                item, var_op_name, master_replica=master_replica_index)
            self._share_variable(item,
                                 var_op_name,
                                 master_replica=master_replica_index)
            master_var_name = ops.prepend_name_scope(
                var_name, replica_prefix(master_replica_index))
            master_var_op_name = get_op_name(master_var_name)
            item.updated = True
            grad, target, update_op = item.var_op_name_to_grad_info_v2[
                master_var_op_name]
            item.var_queried.append(master_var_op_name)
            agg_grad = self._aggregate_gradients(item,
                                                 old_update_op=update_op,
                                                 old_grad=grad,
                                                 old_target=target)

        # update grad_target_pair and variable info
        for i in range(self.num_replicas):
            var_name_to_remove = ops.prepend_name_scope(
                var_name, replica_prefix(i))
            item.pop_gradient_info(var_name=var_name_to_remove)
            if i != master_replica_index:
                item.info.pop_variable(var_name=var_name_to_remove)
        item.extend_gradient_info(
            grads=[agg_grad],
            targets=[item.graph.get_tensor_by_name(master_var_name)])
        # TODO(Hao): Prune the graph to remove unnecessary nodes
        return item
    def _gen_ps_node_config(self, var):
        """
        Creates a NodeConfig specifying synchronization with Parameter Servers.

        Args:
            var (Variable): The variable to generate a config for.

        Returns:
            Dict: the config dict for the node.
        """
        if (len(self.loads) < 1 and not ENV.AUTODIST_IS_TESTING.val) or \
                any((o.type in CONTROL_FLOW_OPS for o in get_consumers(var.op))):
            # Don't partition if there is only one reduction device or if the variable is connected to control flow
            # For stability, we err on the side of not partitioning over potentially breaking
            num_shards = 1
        else:
            num_shards = self.get_num_shards(var)

        # Determine placement of vars/parts
        sorted_ps = sorted(self.loads, key=self.loads.get)
        if num_shards > len(self.loads):
            # If there's more shards than servers, round-robin in greedy order
            sorted_ps = sorted_ps * ceil(num_shards / len(self.loads))
        min_ps = sorted_ps[0:num_shards]
        for ps in min_ps:
            self.loads[ps] += byte_size_load_fn(var) / num_shards

        # setup node config
        node = strategy_pb2.Strategy.Node()
        node.var_name = var.name

        if num_shards == 1:
            node.PSSynchronizer.reduction_destination = min_ps[0]
            node.PSSynchronizer.local_replication = self._local_proxy_variable
            node.PSSynchronizer.sync = self._sync
            node.PSSynchronizer.staleness = self._staleness
        else:
            # generate the partitioner config
            shape = var.initial_value.shape
            partition_list = [1] * len(var.initial_value.shape)
            partition_axis = 0
            partition_list[partition_axis] = min(
                num_shards, shape.dims[partition_axis].value)
            pc = PartitionerConfig(partition_list=partition_list)
            node.partitioner = pc.partition_str

            for i in range(num_shards):
                part = strategy_pb2.Strategy.Node()
                part.var_name = '{}/part_{}:0'.format(get_op_name(var.name), i)
                part.PSSynchronizer.reduction_destination = min_ps[i]
                part.PSSynchronizer.local_replication = self._local_proxy_variable
                part.PSSynchronizer.sync = self._sync
                part.PSSynchronizer.staleness = self._staleness
                node.part_config.extend([part])
        return node
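The placement logic above is a greedy balancing scheme: servers are sorted by current load, the `num_shards` least-loaded ones receive the new shards, and each shard adds an equal share of the variable's byte size to its server's load. A small self-contained illustration of the same arithmetic (device names and sizes are made up):

from math import ceil

loads = {'ps0': 4e6, 'ps1': 1e6, 'ps2': 2e6}    # current bytes assigned per server
num_shards, var_bytes = 4, 8e6                  # hypothetical variable to place

sorted_ps = sorted(loads, key=loads.get)        # ['ps1', 'ps2', 'ps0']
if num_shards > len(loads):
    # More shards than servers: repeat the list so placement round-robins
    # in least-loaded-first order.
    sorted_ps = sorted_ps * ceil(num_shards / len(loads))
min_ps = sorted_ps[:num_shards]                 # ['ps1', 'ps2', 'ps0', 'ps1']
for ps in min_ps:
    loads[ps] += var_bytes / num_shards
# loads -> {'ps0': 6e6, 'ps1': 5e6, 'ps2': 4e6}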
    def _batch_prepend_name_scope(self, to_rename, new_name_scope):
        """
        Construct a new GraphItem with all ops in `to_rename` under `new_name_scope`.

        Args:
            to_rename (set): Collection of ops to rename
            new_name_scope (str): The new name scope to prepend to all ops

        Returns:
            GraphItem
        """
        og_graph_def = self.graph_item.graph.as_graph_def()
        new_graph_def = graph_pb2.GraphDef()
        new_graph_def.library.Clear()
        new_graph_def.library.CopyFrom(og_graph_def.library)
        control_flow_contexts = {}

        for node in og_graph_def.node:
            op = self.graph_item.graph.get_operation_by_name(node.name)

            # Save the control flow context so it can be restored later,
            # since it is not automatically set from the attrs in the graph def
            ctx = op._get_control_flow_context()
            if ctx:
                control_flow_contexts[op.name] = ctx

            if op in to_rename:
                node.name = ops.prepend_name_scope(node.name, new_name_scope)

            # Rename inputs
            for idx, input_name in enumerate(node.input):
                input_op = self.graph_item.graph.get_operation_by_name(
                    get_op_name(input_name))
                if input_op in to_rename:
                    node.input[idx] = ops.prepend_name_scope(
                        input_name, new_name_scope)

            # Fix colocation
            for idx, s in enumerate(node.attr['_class'].list.s):
                name = s[len(COLOCATION_PREFIX):].decode('utf-8')
                if self.graph_item.graph.get_operation_by_name(
                        name) in to_rename:
                    node.attr['_class'].list.s[idx] = (
                        COLOCATION_PREFIX +
                        as_bytes(ops.prepend_name_scope(name, new_name_scope)))

            new_graph_def.node.append(node)

        # Re-add control flow contexts
        new_graph_item = GraphItem(graph_def=new_graph_def)
        for op in new_graph_item.graph.get_operations():
            if op.name in control_flow_contexts:
                op._set_control_flow_context(control_flow_contexts[op.name])

        return new_graph_item
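For context on the colocation fix above: TensorFlow stores colocation constraints in each node's `_class` attr as byte strings of the form b'loc:@<op_name>' (the `COLOCATION_PREFIX` used above), so renaming a colocated op means rewriting the name after that prefix. A tiny illustration of the string manipulation involved (names are made up):

COLOCATION_PREFIX = b'loc:@'
new_name_scope = 'renamed-scope'                      # hypothetical scope

s = b'loc:@dense/kernel'                              # one entry of node.attr['_class'].list.s
name = s[len(COLOCATION_PREFIX):].decode('utf-8')     # 'dense/kernel'
rewritten = COLOCATION_PREFIX + ('%s/%s' % (new_name_scope, name)).encode('utf-8')
# rewritten == b'loc:@renamed-scope/dense/kernel'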
Example #10
    def _gen_node_config(self, var, var_counter):
        """
        Creates a NodeConfig specifying partitioning and synchronization with AllReduce.

        Args:
            var (Variable): The variable to generate a config for.
            var_counter (int): variable counter for collective group assignment.

        Returns:
            Dict: the config dict for the node.
        """
        num_shards = self.get_num_shards(var)

        node = strategy_pb2.Strategy.Node()
        node.var_name = var.name

        if num_shards <= 1:
            node.AllReduceSynchronizer.spec = synchronizers_pb2.AllReduceSynchronizer.Spec.Value(
                "AUTO")
            node.AllReduceSynchronizer.compressor = \
                synchronizers_pb2.AllReduceSynchronizer.Compressor.Value("NoneCompressor")
            # node.AllReduceSynchronizer.compressor = \
            #     synchronizers_pb2.AllReduceSynchronizer.Compressor.Value("PowerSGDCompressor")
            node.AllReduceSynchronizer.group = var_counter // self.chunk_size
            return node, num_shards

        # num_shards > 1 means the variable will be partitioned;
        # generate the partitioner config
        shape = var.initial_value.shape
        partition_list = [1] * len(var.initial_value.shape)
        partition_axis = 0
        partition_list[partition_axis] = min(num_shards,
                                             shape.dims[partition_axis].value)
        num_parts = np.prod(partition_list)
        pc = PartitionerConfig(partition_list=partition_list)
        node.partitioner = pc.partition_str
        for i in range(num_parts):
            part = strategy_pb2.Strategy.Node()

            # If this part var_name is inconsistent with what TF will create, the partitioner
            # kernel will correct it later; here we just keep it consistent.
            part.var_name = '{}/part_{}:0'.format(get_op_name(var.name), i)
            part.AllReduceSynchronizer.spec = synchronizers_pb2.AllReduceSynchronizer.Spec.Value(
                "AUTO")
            part.AllReduceSynchronizer.compressor = \
                synchronizers_pb2.AllReduceSynchronizer.Compressor.Value("NoneCompressor")
            # part.AllReduceSynchronizer.compressor = \
            #     synchronizers_pb2.AllReduceSynchronizer.Compressor.Value("PowerSGDCompressor")
            part.AllReduceSynchronizer.group = (var_counter +
                                                i) // self.chunk_size
            node.part_config.extend([part])
        return node, num_shards
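Two computations above are worth spelling out: the partition list always splits along axis 0, capped by that dimension's size, and each part's collective group is derived from the running variable counter so that at most `chunk_size` tensors share a group. A short illustration with made-up values:

import numpy as np

chunk_size = 128
var_counter = 130                                # hypothetical running counter
num_shards = 4
shape = (10, 300)                                # hypothetical variable shape

partition_list = [1] * len(shape)
partition_list[0] = min(num_shards, shape[0])    # -> [4, 1]
num_parts = int(np.prod(partition_list))         # -> 4

groups = [(var_counter + i) // chunk_size for i in range(num_parts)]
# groups -> [1, 1, 1, 1]: parts numbered 130..133 all land in collective group 1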
Example #11
    def _get_ops_to_delete(self, vars_to_partition):
        """
        Get all ops that need to be deleted based on `vars_to_partition`.

        Also keeps the info object up to date.

        Args:
            vars_to_partition (Dict): Mapping of variable names to (partition_str, reduction_destinations) for vars to be partitioned.

        Returns:
            Set of ops to be deleted.
        """
        to_delete = set()

        # Mark all ops part of the optimizer for deletion
        update_op_scopes = set()
        top_update_op_scopes = set()
        opt_name = self.graph_item.optimizer_args[0]._name
        for var_op_name, (_, _, update_op) in self.graph_item.var_op_name_to_grad_info.items():
            top_level_scope_opt = update_op.name[:update_op.name.find(opt_name) + len(opt_name)]
            # An optimizer can create all its relevant ops under the top level optimizer scope
            update_op_scopes.add(top_level_scope_opt)
            top_update_op_scopes.add(top_level_scope_opt)
            #   as well as nested optimizer scopes under each variable name scope
            update_op_scopes.add(var_op_name + '/' + opt_name)

        for var_name in vars_to_partition:
            var_op_name = get_op_name(var_name)
            var_op = self.graph_item.graph.get_operation_by_name(var_op_name)
            var = self.graph_item.trainable_var_op_to_var[var_op]
            update_op = self.graph_item.var_op_name_to_grad_info[var_op_name][2]
            consumers = get_consumers(var_op)

            # Mark var and all its consumers for deletion
            consumers_to_delete = {c for c in consumers if c.type in MUTABLE_STATE_OP_DIRECT_CONSUMER_OPS}
            to_delete.update([var_op, update_op], consumers_to_delete)

            # Update GraphItem Info
            self.info.pop_variable(var.name)

        to_delete.update({o for o in self.graph_item.graph.get_operations()
                          if any(o.name.startswith(top_level_scope) for top_level_scope in update_op_scopes)})
        # NOTE: Here we assume the name_scope in saver is the default one.
        to_delete.update({o for o in self.graph_item.graph.get_operations()
                          if o.name.startswith("save/")})
        # If the user uses optimizer.get_gradients, gradients are stored under optimizer_name/gradients.
        # We don't want to delete those.
        # There could be other cases which require this logic to be made more robust, though.
        to_delete = {o for o in to_delete
                     if not any(o.name.startswith(tl_scope + '/gradients/') for tl_scope in update_op_scopes)}
        return to_delete, top_update_op_scopes
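The scope bookkeeping above derives the top-level optimizer scope by truncating an update op's name right after the first occurrence of the optimizer name, and also records a per-variable nested scope. A quick illustration with made-up names:

opt_name = 'Adam'
update_op_name = 'Adam/update_dense/kernel/ResourceApplyAdam'
var_op_name = 'dense/kernel'

top_level_scope = update_op_name[:update_op_name.find(opt_name) + len(opt_name)]
# top_level_scope -> 'Adam'
nested_scope = var_op_name + '/' + opt_name
# nested_scope -> 'dense/kernel/Adam'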
    def trainable_var_op_to_var(self):
        """
        Mapping from trainable variable ops (e.g. VarHandleOps) to the Variables.

        Returns:
            Dict
        """
        with self.graph.as_default():
            return {
                self.graph.get_operation_by_name(
                    get_op_name(var_def.variable_name)):
                _from_proto_fn(var_def)
                for var_def in self.info.trainable_variables
            }
Example #13
    def _update_gradient_consumers(new_graph_item, consumer_ops,
                                   control_consumer_ops, old_tensor_name,
                                   new_tensor):
        """Make gradient's consumers consume the aggregated gradient instead of the original one of replica_0."""
        # Get the original tensor (the one from replica 0) to replace
        old_op_name = strip_replica_prefix(get_op_name(old_tensor_name))
        replica_0_op_name = ops.prepend_name_scope(old_op_name,
                                                   replica_prefix(0))
        replica_0_op = new_graph_item.graph.get_operation_by_name(
            replica_0_op_name)
        output_idx = get_index_from_tensor_name(old_tensor_name)
        replica_0_tensor = replica_0_op.outputs[output_idx]

        update_consumers(consumer_ops, replica_0_tensor, new_tensor)
        update_control_consumers(control_consumer_ops, replica_0_tensor.op,
                                 new_tensor.op)
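`update_consumers` and `update_control_consumers` are helpers defined elsewhere. The rewiring they perform can be sketched roughly as below, assuming TensorFlow's low-level graph-editing interface (`Operation._update_input`); the real helpers may differ:

def update_consumers_sketch(consumer_ops, old_tensor, new_tensor):
    """Point each consumer's matching input at new_tensor instead of old_tensor (sketch)."""
    for consumer in consumer_ops:
        for idx, inp in enumerate(consumer.inputs):
            if inp.name == old_tensor.name:
                consumer._update_input(idx, new_tensor)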
Example #14
    def build(self, graph_item, resource_spec):
        """Generate the strategy."""
        expr = Strategy()

        # For each variable, generate variable synchronizer config
        # resource_spec.gpu_devices yields (device_name, DeviceSpec) pairs,
        # e.g. ('162.105.146.118:GPU:0', <DeviceSpec: 162.105.146.118:GPU:0>), ...
        # Hard-coded placement for this experiment: GPUs on hosts whose address contains
        # '119' are excluded from the replicas; the CPUs of the '119' hosts serve as the
        # reduction (PS) devices below.
        gpu_devices = dict()
        for k, v in resource_spec.gpu_devices:
            if '119' not in k:
                gpu_devices[k] = v
        expr.graph_config.replicas.extend([k for k, v in gpu_devices.items()])
        for k, v in resource_spec.node_cpu_devices.items():
            if k not in resource_spec.node_gpu_devices:
                expr.graph_config.replicas.extend(v)
        reduction_device_names = [
            k for k, _ in resource_spec.cpu_devices if '119' in k
        ]
        self.loads = {ps: 0.0 for ps in reduction_device_names}

        # Generate node config
        node_config = []
        for idx, var in enumerate(graph_item.trainable_var_op_to_var.values()):
            var_op_name = get_op_name(var.name)
            grad, _, _ = graph_item.var_op_name_to_grad_info[var_op_name]
            if isinstance(grad, ops.Tensor):  # this is a dense variable
                group_id = idx // self.chunk_size
                config = self._gen_all_reduce_node_config(var.name,
                                                          group=group_id)
            else:  # sparse updates
                # For the Parallax strategy, all PS variables are sparse, so we don't use a proxy:
                # sparse variables tend to be large, keeping local copies would be costly,
                # and each device usually needs only a small slice of the overall variable.
                config = self._gen_ps_node_config(
                    var,
                    False,  # no proxy; see the comment above
                    self._sync,
                    self._staleness)
            node_config.append(config)
        expr.node_config.extend(node_config)

        return expr
Example #15
    def _prune_nodes(self, strategy):
        # Prune the nodes without stateful updates
        s = strategy.copy()
        s.node_config = [n for n in strategy.node_config
                         if get_op_name(n.var_name) in self._graph_item.var_op_name_to_grad_info]
        return s
Example #16
    def _create_new_vars(self, new_graph_item, vars_to_partition, unpartitioned_vars):
        """
        Constructs new partitioned variables in `new_graph_item`.

        Fixes each var's corresponding gradient by splitting the gradient.

        Fixes the optimizer by just constructing a new one using the new variables.

        Args:
            new_graph_item (GraphItem): The GraphItem in which to construct the new variables and ops.
            vars_to_partition (Dict): Mapping of variable names to (partition_str, reduction_destinations) tuples for vars to be partitioned.
            unpartitioned_vars (Dict): Mapping from variable name to gradient name of unpartitioned vars.

        Returns:
            List of new variables.
        """
        new_grads, new_vars = [], []
        partition_config = {}
        with new_graph_item.graph.as_default():
            for var_name, (partition_str, reduction_destinations) in vars_to_partition.items():
                var_op_name = get_op_name(var_name)
                var_op = self.graph_item.graph.get_operation_by_name(var_op_name)
                var = self.graph_item.trainable_var_op_to_var[var_op]
                gradient = self.graph_item.var_op_name_to_grad_info[var_op_name][0]

                # Create partitioned variable and split gradients
                pc = PartitionerConfig(partition_str=partition_str)
                partition_config[var_op_name] = pc

                # Now check compatibility
                if isinstance(gradient, ops.IndexedSlices) and pc.axis != 0:
                    raise ValueError('Embedding variables can only be partitioned along the first axis due to '
                                     'the limitation on the `embedding_lookup_v2` op.')

                initial_value = new_graph_item.graph.get_tensor_by_name(var.initial_value.name)
                # NOTE: to keep the saver working, for now we only support partitioning along one dimension
                # https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/ops/variables.py#L2915
                partitioned_var = vs.get_variable(var_op.name, shape=None, initializer=initial_value,
                                                  partitioner=lambda pconf=pc, **unused_kwargs: pconf.partition_list,
                                                  validate_shape=False, use_resource=True)
                var_list = partitioned_var._variable_list

                # Distribute the partitioned variable slices if they have a PS synchronizer.
                # (This placement may not be strictly necessary.)
                for var_slice, device in zip(var_list, reduction_destinations):
                    if device:
                        var_slice.op._set_device_from_string(device)

                if isinstance(gradient, ops.IndexedSlices):
                    # Sparse variable
                    new_grad = ops.IndexedSlices(
                        indices=new_graph_item.graph.get_tensor_by_name(gradient.indices.name),
                        values=new_graph_item.graph.get_tensor_by_name(gradient.values.name),
                        dense_shape=new_graph_item.graph.get_tensor_by_name(gradient.dense_shape.name)
                    )
                    split_grad = self._split_indexed_slices_v2(new_grad, len(var_list), var.shape[0],
                                                               name=f"gradients/splits/sparse_split_{var_op_name}")
                else:
                    new_grad = new_graph_item.graph.get_tensor_by_name(gradient.name)

                    # new_grad may have an unknown (None) shape, so use the shape of the original var instead
                    split_grad = self._split_tensor_v2(new_grad, pc.num_shards, var.shape, pc.axis,
                                                       name=f"gradients/splits/split_{var_op_name}")
                self._handle_read(new_graph_item, var_op, partitioned_var)
                self._update_node_config(var, var_list)

                self.info.update_variables(var_list, replace=False)
                new_vars.extend(var_list)
                new_grads.extend(split_grad)
                new_graph_item.extend_gradient_info(split_grad, var_list)
                new_graph_item.pop_gradient_info(var.name)
        new_graph_item.info = self.info.copy()
        all_vars, all_grads = new_vars, new_grads
        for var, grad in unpartitioned_vars.items():
            if isinstance(grad, ops.IndexedSlices):
                # Sparse variable
                grad = ops.IndexedSlices(
                    indices=new_graph_item.graph.get_tensor_by_name(grad.indices.name),
                    values=new_graph_item.graph.get_tensor_by_name(grad.values.name),
                    dense_shape=new_graph_item.graph.get_tensor_by_name(grad.dense_shape.name)
                )
            else:
                grad = new_graph_item.graph.get_tensor_by_name(grad.name)
            all_grads.append(grad)
            var = new_graph_item.trainable_var_op_to_var[new_graph_item.graph.get_operation_by_name(get_op_name(var))]
            # Clear any tf.distribute strategy attached to the variable so it does not interfere with AutoDist
            if (not hasattr(var, "_distribute_strategy")) or var._distribute_strategy:
                setattr(var, "_distribute_strategy", None)
            all_vars.append(var)
        with new_graph_item.graph.as_default():
            optimizer = self.graph_item.optimizer(*self.graph_item.optimizer_args[1:],
                                                  **self.graph_item.optimizer_kwargs)
            _ = optimizer.apply_gradients(zip(all_grads, all_vars))
        return new_vars, partition_config
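`_split_indexed_slices_v2` is not shown here. Splitting a sparse gradient across equally sized partitions along axis 0 can be sketched as below, assuming each shard owns a contiguous range of rows (the helper name and details are illustrative):

import tensorflow as tf

def split_indexed_slices_sketch(sparse_grad, num_shards, first_dim_size):
    """Route each row of an IndexedSlices gradient to the shard that owns it (sketch)."""
    shard_size = (first_dim_size + num_shards - 1) // num_shards   # rows per shard
    shard_ids = tf.cast(sparse_grad.indices // shard_size, tf.int32)
    split_values = tf.dynamic_partition(sparse_grad.values, shard_ids, num_shards)
    split_indices = tf.dynamic_partition(sparse_grad.indices, shard_ids, num_shards)
    grads = []
    for shard, (vals, idx) in enumerate(zip(split_values, split_indices)):
        # Re-base indices so they are local to the shard's row range.
        grads.append(tf.IndexedSlices(vals, idx - shard * shard_size))
    return grads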