def _get_accum_apply_and_agg_grad(var_op, grad, indices, dense_shape):
    if indices is None:
        tensor = variable_utils.get_read_var_tensor(var_op)
        grad_accum = data_flow_ops.ConditionalAccumulator(
            grad.dtype,
            shape=tensor.get_shape(),
            shared_name=var_op.name + "/grad_accum")
        # Get a copy of consumers list before creating accum_apply_op
        grad_consumers = list(grad.consumers())
        accum_apply_op = grad_accum.apply_grad(
            grad, local_step=MAX_INT64,
            name=grad.op.name + '_accum_apply_grad')
        agg_grad = grad_accum.take_grad(
            num_accum_required, name=var_op.name + '_take_grad')
        update_consumers(grad_consumers, grad, agg_grad)
        update_control_consumers(
            get_control_consumers(grad.op), grad.op, agg_grad.op)
    else:
        grad_indexed_slices = ops.IndexedSlices(
            values=grad, indices=indices, dense_shape=dense_shape)
        grad_accum = data_flow_ops.SparseConditionalAccumulator(
            grad.dtype,
            shape=grad.shape,
            shared_name=var_op.name + "/grad_accum")
        # Get a copy of consumers list before creating accum_apply_op
        indices_consumers = list(indices.consumers())
        grad_consumers = list(grad.consumers())
        accum_apply_op = grad_accum.apply_indexed_slices_grad(
            grad_indexed_slices,
            local_step=MAX_INT64,
            name=grad.op.name + '_accum_apply_grad')
        agg_grad = grad_accum.take_indexed_slices_grad(
            num_accum_required, name=var_op.name + '_take_grad')
        agg_indices = agg_grad.indices
        if indices.dtype != agg_grad.indices.dtype:
            agg_indices = math_ops.cast(agg_grad.indices, indices.dtype)
        agg_grad = ops.IndexedSlices(
            values=agg_grad.values,
            indices=agg_indices,
            dense_shape=agg_grad.dense_shape)
        assert isinstance(agg_grad, ops.IndexedSlices)
        update_consumers(indices_consumers, indices, agg_grad.indices)
        update_consumers(grad_consumers, grad, agg_grad.values)
        update_control_consumers(
            get_control_consumers(indices.op), indices.op, agg_grad.indices.op)
        update_control_consumers(
            get_control_consumers(grad.op), grad.op, agg_grad.values.op)
    return accum_apply_op, agg_grad
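
# Illustrative sketch (not part of the original listing): a minimal, standalone
# example of the ConditionalAccumulator pattern used in the dense branch above.
# Each worker applies its local gradient and take_grad() returns the aggregate
# (the mean, by default) once the required number of gradients has been
# applied. The accumulator name and the toy values are made up.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

accum = tf.ConditionalAccumulator(tf.float32, shape=[2], shared_name="demo_accum")
# Two "workers" each contribute one gradient for the current step.
apply_ops = [accum.apply_grad(tf.constant([1.0, 2.0]) * (i + 1)) for i in range(2)]
# take_grad blocks until two gradients have been applied, then returns their mean.
aggregated = accum.take_grad(2)

with tf.Session() as sess:
    sess.run(apply_ops)
    print(sess.run(aggregated))  # -> [1.5, 3.0] (element-wise mean)
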
    def _collect_sparse_gradients(self, graph_item, var_op_name):
        """Append collective ops after the gradient is calculated."""
        if self.num_workers > 1 and not ENV.AUTODIST_INTERNAL_TF.value:
            raise NotImplementedError(
                'Currently the collective NCCL AllGather is not supported in '
                'the TensorFlow release. Please choose another strategy.')
        conf = {}
        if self._spec:
            conf = {'communication_hint': self._spec}
        if self._compressor_type:
            logging.warning(
                'AllGather currently does not support the AutoDist compressor, '
                'so compression is skipped.')
        if self.num_replicas * self.num_workers <= 1:
            raise ValueError(
                'CollectiveOps requires collective group size > 1')
        for i in range(0, self.num_replicas):
            op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
            graph_item.updated = True
            grad, _, _ = graph_item.var_op_name_to_grad_info_v2[op_name]
            # TODO (Tairui): (3) Merge of reduction for performance
            indices_c_ops = grad.indices.consumers()
            indices_cc_ops = get_control_consumers(grad.indices.op)
            values_c_ops = grad.values.consumers()
            values_cc_ops = get_control_consumers(grad.values.op)
            with ops.name_scope(replica_prefix(i)):
                with ops.colocate_with(grad.indices.op):
                    new_indices = collective_ops.all_gather(
                        grad.indices, self.num_replicas * self.num_workers,
                        get_collective_keys().get_group_key(
                            self.all_canonical_replica_devices),
                        get_collective_keys().get_instance_key(
                            var_op_name + '-indices'),
                        **conf)
                with ops.colocate_with(grad.values.op):
                    new_values = collective_ops.all_gather(
                        grad.values, self.num_replicas * self.num_workers,
                        get_collective_keys().get_group_key(
                            self.all_canonical_replica_devices),
                        get_collective_keys().get_instance_key(
                            var_op_name + '-values'),
                        **conf)
            update_consumers(indices_c_ops, grad.indices, new_indices)
            update_control_consumers(
                indices_cc_ops, grad.indices.op, new_indices.op)
            update_consumers(values_c_ops, grad.values, new_values)
            update_control_consumers(
                values_cc_ops, grad.values.op, new_values.op)
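
# Illustrative sketch (not part of the original listing): simulates the effect
# of the per-replica AllGather above. The indices and values of each replica's
# IndexedSlices gradient are gathered separately; rows contributed by several
# replicas are summed when the sparse gradient is later applied (shown here
# with unsorted_segment_sum). collective_ops.all_gather itself needs a
# multi-device collective group, so tf.concat stands in for it.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Per-replica sparse gradients for a 4-row, 2-column embedding variable.
replica_grads = [
    tf.IndexedSlices(values=tf.constant([[1., 1.], [2., 2.]]),
                     indices=tf.constant([0, 3]),
                     dense_shape=tf.constant([4, 2])),
    tf.IndexedSlices(values=tf.constant([[3., 3.]]),
                     indices=tf.constant([0]),
                     dense_shape=tf.constant([4, 2])),
]

# "All-gather" the indices and the values separately.
gathered_indices = tf.concat([g.indices for g in replica_grads], axis=0)
gathered_values = tf.concat([g.values for g in replica_grads], axis=0)

# Densify: duplicate indices (row 0 here) are summed.
dense_grad = tf.math.unsorted_segment_sum(gathered_values, gathered_indices, 4)

with tf.Session() as sess:
    print(sess.run(dense_grad))  # row 0 -> [4., 4.], row 3 -> [2., 2.], rest 0
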
    @staticmethod
    def _get_optimizer_source_op(update_op):
        """
        Identify the additional no_op between update_op and train_op if it exists (for certain optimizers).

        Args:
            update_op: the op that applies the gradient update to the variable.

        Returns:
            source_op: the no_op if it exists, otherwise the update_op itself.
        """
        group_deps = [
            op for op in get_control_consumers(update_op) if 'Adam' in op.name
            and 'group_deps' in op.name and op.type == 'NoOp'
        ]
        source_op = group_deps[0] if group_deps else update_op
        return source_op
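
# Illustrative sketch (not part of the original listing): the filter above
# looks, among the control consumers of the update op, for a NoOp whose name
# contains both 'Adam' and 'group_deps'. tf.group() is what produces such a
# NoOp; it depends on the grouped ops through control edges. The names below
# are chosen only so that they would match that filter.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

update_op = tf.no_op(name="update")                     # stand-in for a variable update op
train_op = tf.group(update_op, name="Adam/group_deps")  # grouping NoOp between update and train

print(train_op.type)                                    # NoOp
print([op.name for op in train_op.control_inputs])      # ['update']
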
    def _prune_control_dependencies(self,
                                    graph_item,
                                    var_op_name,
                                    master_replica=0):
        """
        Prune the control dependencies between the train_op on non-master replica and update op.

        Since the replicator will replicate the entire graph, the update op on a non-master replica
        will also be replicated. If the train_op on a non-master replica is fetched (which is the case
        in our current feed-fetch remap implementation), it will trigger those update ops and result
        in unnecessary updates to the trainable variables.
        This function prunes the control dependencies between the train_op and any variable that is
        synchronized by a PS synchronizer to avoid this situation.
        """
        for i in range(self.num_replicas):
            if i == master_replica:
                continue
            this_var_op_name = ops.prepend_name_scope(var_op_name,
                                                      replica_prefix(i))
            _, _, update_op = graph_item.var_op_name_to_grad_info[
                this_var_op_name]
            source_op = self._get_optimizer_source_op(update_op)
            remove_from_control_consumers(get_control_consumers(source_op),
                                          source_op)
    def add_sync_op(self, graph_item, var_update_op, variable_replicator=None):
        """
        Add the additional ops needed for synchronous distributed training into the current graph.

        The main purposes of the additional ops are:
        1. Initialization
        2. Synchronization
        3. Gradient aggregation

        Args:
            graph_item (graph_item.GraphItem): the graph
            var_update_op: The op that applies the gradient update to the variable.
            variable_replicator: The replicator that manages the locally replicated
                (mirrored) copies of the master variable; may be None.

        Returns:
            None
        """
        this_worker_cpu = device_spec.DeviceSpecV2.from_string(
            self.worker_device)
        this_worker_cpu = this_worker_cpu.replace(device_type='CPU',
                                                  device_index=0)

        var_op = var_update_op.inputs[UPDATE_OP_VAR_POS].op
        is_trainable = var_op in graph_item.trainable_var_op_to_var
        source_op = self._get_optimizer_source_op(var_update_op)
        cc = get_control_consumers(source_op)

        with ops.device(var_op.device):
            if self._staleness == 0:
                queue_ops = self._get_queue_ops(var_update_op, source_op,
                                                self.is_chief, is_trainable)
            elif self._staleness > 0:
                queue_ops = self._get_queue_ops_stale(var_update_op, source_op,
                                                      self.is_chief,
                                                      is_trainable)
            else:
                raise ValueError(
                    "staleness should be greater than or equal to 0.")

            # Only dense trainable variables are replicated locally
            if variable_replicator:
                mirror_variable_update_ops = variable_replicator.get_all_update_ops(
                    queue_ops, worker_device=this_worker_cpu)
                with ops.device(this_worker_cpu):
                    finish_op = control_flow_ops.group(
                        *mirror_variable_update_ops)
            else:
                finish_op = control_flow_ops.group(*queue_ops)

        # Place computation ops of aggregated gradients on PS
        # Note that even though this is doing a graph traversal, it is called in such a way that it
        # only traverses from a gradient aggregator op to a gradient application op (or vice versa) --
        # these corresponding ops should always be adjacent in the graph.
        self._place_post_grad_agg_ops(
            device_spec.DeviceSpecV2.from_string(self.target_device),
            self._var_op_to_agg_grad,
            {var_op: var_update_op} if is_trainable else {})

        # Replace the control input of the train_op with finish_op
        # Note(Hao): this cc is stale, i.e. cc \subset get_control_consumers(source_op)
        update_control_consumers(cc, source_op, finish_op)
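
# Illustrative sketch (not part of the original listing): _get_queue_ops is not
# shown here, so this is only a generic token-queue barrier of the kind such
# queue ops typically build, assuming a chief that enqueues one token per
# worker after its update and workers that block on a dequeue before the next
# step. The queue name and sizes are made up.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

num_workers = 2
sync_queue = tf.FIFOQueue(num_workers, dtypes=[tf.bool], shapes=[[]],
                          shared_name="demo/sync_queue")

# Chief-side op: release every worker for this step.
release_workers = sync_queue.enqueue_many(tf.constant([True] * num_workers))
# Worker-side op: block until the chief has released the step.
wait_for_chief = sync_queue.dequeue()

with tf.Session() as sess:
    sess.run(release_workers)
    sess.run(wait_for_chief)  # returns as soon as a token is available
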