Example #1
    def _collect_dense_gradients(self, graph_item, var_op_name):
        """Append collective ops after the gradient is calculated."""
        if self.num_replicas * self.num_workers <= 1:
            raise ValueError(
                'CollectiveOps requires collective group size > 1')

        compressors = defaultdict(
            lambda: Compressor.create(self._compressor_type, var_op_name))

        conf = CollectiveOpsConfig()
        conf.group_size = len(self.all_canonical_replica_devices)
        conf.group_key = get_collective_keys().get_group_key(
            self.all_canonical_replica_devices)
        conf.instance_key = get_collective_keys().get_instance_key(var_op_name)
        conf.merge_op = 'Add'
        conf.final_op = 'Div'
        if self._spec:
            setattr(conf, 'communication_hint', self._spec)

        for i in range(0, self.num_replicas):
            op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
            graph_item.updated = True
            grad, _, _ = graph_item.var_op_name_to_grad_info_v2[op_name]
            # TODO (Tairui): (3) Merge of reduction for performance
            grad_consumers = get_consumers(
                grad.op)  # this line must happen before the reduction

            # "\/" is added for name scope reuse
            with ops.name_scope(
                    replica_prefix(i) +
                    "/collective-group-{}/".format(self._group)):
                with ops.colocate_with(grad.op):
                    reduced_grad = compressors[i].reduce(grad, conf)
            update_consumers(grad_consumers, grad, reduced_grad)
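
Every example in this listing rewires graph edges through the `update_consumers` helper. For reference, below is a minimal sketch of what such a helper can do, assuming TensorFlow's private `Operation._update_input` API; it illustrates the technique and is not necessarily AutoDist's exact implementation.

    def update_consumers(consumers, old_tensor, new_tensor):
        """Make each op in `consumers` read `new_tensor` wherever it currently reads `old_tensor`."""
        for consumer_op in consumers:
            for i, input_tensor in enumerate(consumer_op.inputs):
                if input_tensor is old_tensor:
                    # Private TF API: rewrite the i-th input edge of the op in place.
                    consumer_op._update_input(i, new_tensor)
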
Example #2
    def _update_consumer(self, other, consumer_op):
        """
        Update consumer.

        Args:
            other (Operation): Other resource var op.
            consumer_op (Operation): The new consumer.
        """
        old_read_var_op = self._consumer_to_read_var_op[consumer_op]
        new_read_var_op = self._read_var_ops_mappings[other][old_read_var_op]
        update_consumers([consumer_op],
                         old_tensor=old_read_var_op.outputs[0],
                         new_tensor=new_read_var_op.outputs[0])
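
For context on the two mappings used above, their shapes can be read off the lookup code; the comments below restate them (semantics inferred from this snippet, not from AutoDist documentation).

    # _consumer_to_read_var_op : {consumer Operation -> the read-variable Operation it currently uses}
    # _read_var_ops_mappings   : {other var Operation -> {old read-variable Operation -> replacement read-variable Operation}}
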
Example #3
    def _update_gradient_consumers(new_graph_item, consumer_ops,
                                   control_consumer_ops, old_tensor_name,
                                   new_tensor):
        """Make gradient's consumers consume the aggregated gradient instead of the original one of replica_0."""
        # Get the original tensor (the one from replica 0) to replace
        old_op_name = strip_replica_prefix(get_op_name(old_tensor_name))
        replica_0_op_name = ops.prepend_name_scope(old_op_name,
                                                   replica_prefix(0))
        replica_0_op = new_graph_item.graph.get_operation_by_name(
            replica_0_op_name)
        output_idx = get_index_from_tensor_name(old_tensor_name)
        replica_0_tensor = replica_0_op.outputs[output_idx]

        update_consumers(consumer_ops, replica_0_tensor, new_tensor)
        update_control_consumers(control_consumer_ops, replica_0_tensor.op,
                                 new_tensor.op)
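
Example #3 also moves control-dependency edges via `update_control_consumers`. A comparable sketch, relying on TensorFlow's private `Operation._remove_all_control_inputs` and `Operation._add_control_inputs` APIs, and intended only as an illustration of the technique:

    def update_control_consumers(control_consumer_ops, old_op, new_op):
        """Replace a control dependency on `old_op` with one on `new_op` for each consumer."""
        for consumer_op in control_consumer_ops:
            new_control_inputs = [new_op if dep is old_op else dep
                                  for dep in consumer_op.control_inputs]
            # Private TF APIs: detach all control edges, then re-attach the updated list.
            consumer_op._remove_all_control_inputs()
            consumer_op._add_control_inputs(new_control_inputs)
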
Example #4
 def _collect_sparse_gradients(self, graph_item, var_op_name):
     """Append collective ops after the gradient is calculated."""
     if self.num_workers > 1 and not ENV.AUTODIST_INTERNAL_TF.value:
         raise NotImplementedError(
             'Currently the collective NCCL AllGather is not supported in the TensorFlow release. '
             'Please choose another strategy.')
     conf = {}
     if self._spec:
         conf = {'communication_hint': self._spec}
     if self._compressor_type:
         logging.warning(
             'AllGather currently does not support an AutoDist compressor, so compression is skipped.'
         )
     if self.num_replicas * self.num_workers <= 1:
         raise ValueError(
             'CollectiveOps requires collective group size > 1')
     for i in range(0, self.num_replicas):
         op_name = ops.prepend_name_scope(var_op_name, replica_prefix(i))
         graph_item.updated = True
         grad, _, _ = graph_item.var_op_name_to_grad_info_v2[op_name]
         # TODO (Tairui): (3) Merge of reduction for performance
         indices_c_ops = grad.indices.consumers()
         indices_cc_ops = get_control_consumers(grad.indices.op)
         values_c_ops = grad.values.consumers()
         values_cc_ops = get_control_consumers(grad.values.op)
         with ops.name_scope(replica_prefix(i)):
             with ops.colocate_with(grad.indices.op):
                 new_indices = collective_ops.all_gather(
                     grad.indices, self.num_replicas * self.num_workers,
                     get_collective_keys().get_group_key(
                         self.all_canonical_replica_devices),
                     get_collective_keys().get_instance_key(var_op_name +
                                                            '-indices'),
                     **conf)
             with ops.colocate_with(grad.values.op):
                 new_values = collective_ops.all_gather(
                     grad.values, self.num_replicas * self.num_workers,
                     get_collective_keys().get_group_key(
                         self.all_canonical_replica_devices),
                     get_collective_keys().get_instance_key(var_op_name +
                                                            '-values'),
                     **conf)
         update_consumers(indices_c_ops, grad.indices, new_indices)
         update_control_consumers(indices_cc_ops, grad.indices.op,
                                  new_indices.op)
         update_consumers(values_c_ops, grad.values, new_values)
         update_control_consumers(values_cc_ops, grad.values.op, new_values.op)
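
Example #4 needs both the data consumers and the control consumers of the gradient's `indices` and `values` tensors. A rough sketch of how such lookups can be built from public TF graph APIs (illustrative only; the real AutoDist helpers may differ):

    def get_consumers(op):
        """All ops that consume any output tensor of `op`."""
        return [consumer for output in op.outputs for consumer in output.consumers()]

    def get_control_consumers(op):
        """All ops holding a control dependency on `op`, found by scanning the graph."""
        return [other for other in op.graph.get_operations() if op in other.control_inputs]
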
Example #5
 def _handle_read(new_graph_item, var_op, partitioned_var):
     partitioned_var_tensor = partitioned_var.as_tensor()
     for op in get_consumers(var_op):
         op = new_graph_item.graph.get_operation_by_name(
             ops.prepend_name_scope(op.name, AUTODIST_TO_DELETE_SCOPE)
         )
         if op.type == "ResourceGather":
             # Only Resource Variable needs to be taken care of
             #   because ResourceGather consumes resource tensor rather than the tensor of read_var_op
             # Question: is there any case where op.type == "ResourceGather"
             #  but embedding_lookup_v2 cannot be used to reconstruct the op consuming a partitioned resource?
             # The second input to a ResourceGather op is always the indices per the opdef
             emb_lookup = embedding_ops.embedding_lookup_v2(partitioned_var, ids=op.inputs[1])
             update_consumers(get_consumers(op), op.outputs[0], emb_lookup)
         if is_read_var_op(op, version=1):
             # Without our modification, Reference Vars in TF have a read op associated with them.
             # TF can sometimes look for this and expect it to exist (e.g. in graph.as_graph_element)
             # so we add one back to avoid errors.
             # read_out is the output tensor of the identity op generated below
             read_out = array_ops.identity(partitioned_var_tensor,
                                           name=ops.prepend_name_scope("read", var_op.name))
             update_consumers(get_consumers(op), op.outputs[0], read_out)
         elif is_read_var_op(op, version=2):
             read_out = array_ops.identity(partitioned_var_tensor,
                                           name=ops.prepend_name_scope("Read/ReadVariableOp", var_op.name))
             update_consumers(get_consumers(op), op.outputs[0], read_out)
 def _get_accum_apply_and_agg_grad(var_op, grad, indices, dense_shape):
     if indices is None:
         tensor = variable_utils.get_read_var_tensor(var_op)
         grad_accum = data_flow_ops.ConditionalAccumulator(
             grad.dtype,
             shape=tensor.get_shape(),
             shared_name=var_op.name + "/grad_accum")
         # Get a copy of consumers list before creating accum_apply_op
         grad_consumers = list(grad.consumers())
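         # NOTE: MAX_INT64 and num_accum_required (used below) are assumed to come from
         # the enclosing scope of the original source; they are not defined in this snippet.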
         accum_apply_op = grad_accum.apply_grad(grad,
                                                local_step=MAX_INT64,
                                                name=grad.op.name +
                                                '_accum_apply_grad')
         agg_grad = grad_accum.take_grad(num_accum_required,
                                         name=var_op.name +
                                         '_take_grad')
         update_consumers(grad_consumers, grad, agg_grad)
         update_control_consumers(get_control_consumers(grad.op),
                                  grad.op, agg_grad.op)
     else:
         grad_indexed_slices = ops.IndexedSlices(
             values=grad, indices=indices, dense_shape=dense_shape)
         grad_accum = data_flow_ops.SparseConditionalAccumulator(
             grad.dtype,
             shape=grad.shape,
             shared_name=var_op.name + "/grad_accum")
         # Get a copy of consumers list before creating accum_apply_op
         indices_consumers = list(indices.consumers())
         grad_consumers = list(grad.consumers())
         accum_apply_op = grad_accum.apply_indexed_slices_grad(
             grad_indexed_slices,
             local_step=MAX_INT64,
             name=grad.op.name + '_accum_apply_grad')
         agg_grad = grad_accum.take_indexed_slices_grad(
             num_accum_required, name=var_op.name + '_take_grad')
         agg_indices = agg_grad.indices
         if indices.dtype != agg_grad.indices.dtype:
             agg_indices = math_ops.cast(agg_grad.indices,
                                         indices.dtype)
         agg_grad = ops.IndexedSlices(values=agg_grad.values,
                                      indices=agg_indices,
                                      dense_shape=agg_grad.dense_shape)
         assert isinstance(agg_grad, ops.IndexedSlices)
         update_consumers(indices_consumers, indices, agg_grad.indices)
         update_consumers(grad_consumers, grad, agg_grad.values)
         update_control_consumers(get_control_consumers(indices.op),
                                  indices.op, agg_grad.indices.op)
         update_control_consumers(get_control_consumers(grad.op),
                                  grad.op, agg_grad.values.op)
     return accum_apply_op, agg_grad
    def _share_variable(self, graph_item, var_op_name, master_replica=0):
        """
        Share the variable on replica `master_replica` (defaults to 0).

        Update the inputs of the variable's consumers on replicas other than `master_replica` to use the variable on replica `master_replica`.

        Args:
            graph_item: the old graph item
            var_op_name: the name of the variable op of the variable to be shared
            master_replica: the index of the master replica (defaults to 0)
        """
        for i in range(0, self.num_replicas):
            if i == master_replica:
                continue
            this_var_op_name = ops.prepend_name_scope(var_op_name,
                                                      replica_prefix(i))
            this_var_op = graph_item.graph.get_operation_by_name(
                this_var_op_name)

            # Get all read variable ops to this replica variable
            read_var_ops = get_read_var_ops(this_var_op)

            # Get all consumers of its VarHandleOp,
            # excluding ReadVariableOps and those under its own name scope
            handle_consumers = set(get_consumers(this_var_op))
            handle_consumers.difference_update(set(read_var_ops))
            handle_consumers.difference_update({
                con
                for con in handle_consumers
                if con.name.startswith(this_var_op_name + '/')
            })
            # We exclude the update ops when updating the consumers of the shared variables,
            # because i) sharing a variable implies sharing its stateful ops correspondingly
            # (so it would be fine to remove stateful ops in non-master replicas; here we simply disconnect them),
            # and ii) a variable cannot correspond to more than one update op for now.
            handle_consumers.difference_update(set(graph_item.all_update_ops))

            # update the consumers of all read variable ops to use the read variable ops of replica=master_replica
            for read_var_op in read_var_ops:
                new_read_var_op_name = ops.prepend_name_scope(
                    ops.strip_name_scope(read_var_op.name, replica_prefix(i)),
                    replica_prefix(master_replica))
                new_read_var_op = graph_item.graph.get_operation_by_name(
                    new_read_var_op_name)
                consumers = get_consumers(read_var_op)
                update_consumers(consumers, read_var_op.outputs[0],
                                 new_read_var_op.outputs[0])
                update_colocation_group(consumers, read_var_op,
                                        new_read_var_op)

            # update the consumers of the VarHandleOp to use the handle on replica=master_replica
            new_handle_op_name = ops.prepend_name_scope(
                ops.strip_name_scope(this_var_op_name, replica_prefix(i)),
                replica_prefix(master_replica))
            new_handle_op = graph_item.graph.get_operation_by_name(
                new_handle_op_name)
            handle_consumers = list(handle_consumers)
            update_consumers(handle_consumers, this_var_op.outputs[0],
                             new_handle_op.outputs[0])
            update_colocation_group(handle_consumers, this_var_op,
                                    new_handle_op)
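
Example #5's `_share_variable` keeps consumers colocated with the shared variable through `update_colocation_group`. One hedged way to implement such a helper is to rewrite each consumer's `_class` (colocation) attribute via TF's private `Operation._set_attr`; everything below is an assumption about the approach, not AutoDist's actual code.

    from tensorflow.core.framework import attr_value_pb2

    def update_colocation_group(ops_to_update, old_op, new_op):
        """Repoint colocation constraints from `old_op` to `new_op` on the given ops."""
        old_group = b"loc:@" + old_op.name.encode()
        new_group = b"loc:@" + new_op.name.encode()
        for op in ops_to_update:
            # colocation_groups() returns entries of the form b"loc:@some_op_name".
            groups = [new_group if g == old_group else g for g in op.colocation_groups()]
            # Private TF API: overwrite the op's "_class" attribute with the updated groups.
            op._set_attr(
                "_class",
                attr_value_pb2.AttrValue(
                    list=attr_value_pb2.AttrValue.ListValue(s=groups)))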