Example #1
def _get_accum_apply_and_agg_grad(var_op, grad, indices, dense_shape):
    """Accumulate `grad` with a (Sparse)ConditionalAccumulator and rewire its consumers to the aggregated gradient."""
    if indices is None:
        tensor = variable_utils.get_read_var_tensor(var_op)
        grad_accum = data_flow_ops.ConditionalAccumulator(
            grad.dtype,
            shape=tensor.get_shape(),
            shared_name=var_op.name + "/grad_accum")
        # Get a copy of the consumers list before creating accum_apply_op
        grad_consumers = list(grad.consumers())
        accum_apply_op = grad_accum.apply_grad(grad,
                                               local_step=MAX_INT64,
                                               name=grad.op.name +
                                               '_accum_apply_grad')
        agg_grad = grad_accum.take_grad(num_accum_required,
                                        name=var_op.name +
                                        '_take_grad')
        update_consumers(grad_consumers, grad, agg_grad)
        update_control_consumers(get_control_consumers(grad.op),
                                 grad.op, agg_grad.op)
    else:
        grad_indexed_slices = ops.IndexedSlices(
            values=grad, indices=indices, dense_shape=dense_shape)
        grad_accum = data_flow_ops.SparseConditionalAccumulator(
            grad.dtype,
            shape=grad.shape,
            shared_name=var_op.name + "/grad_accum")
        # Get a copy of the consumers lists before creating accum_apply_op
        indices_consumers = list(indices.consumers())
        grad_consumers = list(grad.consumers())
        accum_apply_op = grad_accum.apply_indexed_slices_grad(
            grad_indexed_slices,
            local_step=MAX_INT64,
            name=grad.op.name + '_accum_apply_grad')
        agg_grad = grad_accum.take_indexed_slices_grad(
            num_accum_required, name=var_op.name + '_take_grad')
        agg_indices = agg_grad.indices
        if indices.dtype != agg_grad.indices.dtype:
            agg_indices = math_ops.cast(agg_grad.indices,
                                        indices.dtype)
        agg_grad = ops.IndexedSlices(values=agg_grad.values,
                                     indices=agg_indices,
                                     dense_shape=agg_grad.dense_shape)
        assert isinstance(agg_grad, ops.IndexedSlices)
        update_consumers(indices_consumers, indices, agg_grad.indices)
        update_consumers(grad_consumers, grad, agg_grad.values)
        update_control_consumers(get_control_consumers(indices.op),
                                 indices.op, agg_grad.indices.op)
        update_control_consumers(get_control_consumers(grad.op),
                                 grad.op, agg_grad.values.op)
    return accum_apply_op, agg_grad
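
The dense branch above relies on TensorFlow's ConditionalAccumulator to average one gradient per replica. Below is a minimal, self-contained sketch of that apply_grad / take_grad pattern (illustrative shapes and names, assuming TF 1.x-style graph execution via tf.compat.v1), not AutoDist's actual wiring:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# Accumulator that averages dense gradients of shape (2,).
accum = tf.compat.v1.ConditionalAccumulator(
    dtype=tf.float32, shape=(2,), shared_name="demo/grad_accum")

# Two "replicas" each apply a gradient; local_step=0 matches the accumulator's
# initial global step, so neither gradient is considered stale.
apply_ops = [accum.apply_grad(tf.constant([1.0, 2.0]), local_step=0),
             accum.apply_grad(tf.constant([3.0, 4.0]), local_step=0)]

# take_grad waits until num_required gradients have arrived, then returns their mean.
with tf.control_dependencies(apply_ops):
    averaged = accum.take_grad(num_required=2)

with tf.compat.v1.Session() as sess:
    print(sess.run(averaged))  # [2. 3.]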
Example #2
    def _get_aggregated_sparse_grad(self, graph_item, var_op, grad,
                                    reduce_to_device, BFTaggregator):
        indices_op_name = strip_replica_prefix(get_op_name(grad.indices.name))
        values_op_name = strip_replica_prefix(get_op_name(grad.values.name))
        dense_shape_op_name = strip_replica_prefix(
            get_op_name(grad.dense_shape.name))

        indexed_slices_grads = []
        for i in range(self.num_replicas):
            indices_op = graph_item.graph.get_operation_by_name(
                ops.prepend_name_scope(indices_op_name, replica_prefix(i)))
            values_op = graph_item.graph.get_operation_by_name(
                ops.prepend_name_scope(values_op_name, replica_prefix(i)))
            dense_shape_op = graph_item.graph.get_operation_by_name(
                ops.prepend_name_scope(dense_shape_op_name, replica_prefix(i)))
            indexed_slices_grads.append(
                ops.IndexedSlices(
                    values_op.outputs[utils.get_index_from_tensor_name(
                        grad.values.name)],
                    indices_op.outputs[utils.get_index_from_tensor_name(
                        grad.indices.name)],
                    dense_shape_op.outputs[utils.get_index_from_tensor_name(
                        grad.dense_shape.name)]))

        return self._aggregate_sparse_gradients(var_op, reduce_to_device,
                                                indexed_slices_grads,
                                                values_op_name)
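
Example #2 depends on recovering per-replica tensors purely by name. Here is a small sketch of that lookup pattern (illustrative graph and names, using ops.prepend_name_scope from tensorflow.python.framework in TF 1.x graph mode), separate from AutoDist's replica machinery:

import tensorflow as tf
from tensorflow.python.framework import ops

tf.compat.v1.disable_eager_execution()

graph = tf.Graph()
with graph.as_default(), tf.name_scope("replica_0"):
    # Stand-in for one replica's gradient-indices tensor.
    tf.constant([0, 2], name="grad_indices")

# Rebuild the replica-prefixed op name and fetch the op and its output tensor.
full_op_name = ops.prepend_name_scope("grad_indices", "replica_0")
indices_op = graph.get_operation_by_name(full_op_name)   # 'replica_0/grad_indices'
indices_tensor = indices_op.outputs[0]
print(indices_tensor.name)  # replica_0/grad_indices:0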
Example #3
    def _aggregate_sparse_gradients(self, var_op, reduce_to_device,
                                    indexed_slices_grads, values_op_name):
        with ops.device(reduce_to_device):
            grad_accum_op_name = ops.prepend_name_scope(
                values_op_name, u"%sAccum" % AUTODIST_PREFIX)
            grad_accum = data_flow_ops.SparseConditionalAccumulator(
                dtype=indexed_slices_grads[0].values.dtype,
                shape=var_op.outputs[0].shape,
                shared_name=grad_accum_op_name,
                name=grad_accum_op_name)
            accum_apply_ops = [
                grad_accum.apply_indexed_slices_grad(
                    indexed_slices_grads[i],
                    MAX_INT64,
                    name=ops.prepend_name_scope(
                        values_op_name, u"%s-Accum-Apply" % replica_prefix(i)))
                for i in range(self.num_replicas)
            ]
            take_grad_op_name = ops.prepend_name_scope(
                values_op_name, u"%sTake-Grad" % AUTODIST_PREFIX)
            with ops.control_dependencies(accum_apply_ops):
                take_grad = grad_accum.take_indexed_slices_grad(
                    self.num_replicas, name=take_grad_op_name)

            new_indices = take_grad.indices
            new_values = take_grad.values
            new_dense_shape = take_grad.dense_shape
            if indexed_slices_grads[0].indices.dtype != new_indices.dtype:
                new_indices = math_ops.cast(
                    new_indices,
                    indexed_slices_grads[0].indices.dtype,
                    name=ops.prepend_name_scope(
                        values_op_name,
                        u"%sTake-Grad-Cast-Indices" % AUTODIST_PREFIX))
            if (indexed_slices_grads[0].dense_shape.dtype
                    != new_dense_shape.dtype):
                new_dense_shape = math_ops.cast(
                    new_dense_shape,
                    indexed_slices_grads[0].dense_shape.dtype,
                    name=ops.prepend_name_scope(
                        values_op_name,
                        u"%sTake-Grad-Cast-Shape" % AUTODIST_PREFIX))
        return ops.IndexedSlices(new_values, new_indices, new_dense_shape)
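
The same accumulator idea applies to sparse gradients. Below is a minimal sketch (illustrative shapes, assuming TF 1.x-style graph execution) of SparseConditionalAccumulator aggregating per-replica IndexedSlices, which is the core of _aggregate_sparse_gradients:

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# Accumulator for sparse gradients of a [4, 3] variable.
accum = tf.compat.v1.SparseConditionalAccumulator(
    dtype=tf.float32, shape=tf.TensorShape([4, 3]), shared_name="demo/sparse_accum")

# Two "replica" gradients as IndexedSlices (indices given as int64).
replica_grads = [
    tf.IndexedSlices(values=tf.ones([2, 3]),
                     indices=tf.constant([0, 2], tf.int64),
                     dense_shape=tf.constant([4, 3], tf.int64)),
    tf.IndexedSlices(values=tf.ones([1, 3]),
                     indices=tf.constant([1], tf.int64),
                     dense_shape=tf.constant([4, 3], tf.int64)),
]
apply_ops = [accum.apply_indexed_slices_grad(g, local_step=0) for g in replica_grads]

# Once both replicas have applied, take the aggregated sparse gradient.
with tf.control_dependencies(apply_ops):
    agg = accum.take_indexed_slices_grad(num_required=2)  # an IndexedSlices

with tf.compat.v1.Session() as sess:
    print(sess.run([agg.indices, agg.values]))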
Example #4
    def _create_new_vars(self, new_graph_item, vars_to_partition, unpartitioned_vars):
        """
        Constructs new partitioned variables in `new_graph_item`.

        Fixes each var's corresponding gradient by splitting the gradient.

        Fixes the optimizer by just constructing a new one using the new variables.

        Args:
            new_graph_item (GraphItem): The GraphItem in which to construct the new variables and ops.
            vars_to_partition (Dict): Mapping from variable name to its (partition string, reduction destinations) pair for vars to be partitioned.
            unpartitioned_vars (Dict): Mapping from variable name to the gradient (Tensor or IndexedSlices) of unpartitioned vars.

        Returns:
            Tuple of the list of new variables and the per-variable partition config.
        """
        new_grads, new_vars = [], []
        partition_config = {}
        with new_graph_item.graph.as_default():
            for var_name, (partition_str, reduction_destinations) in vars_to_partition.items():
                var_op_name = get_op_name(var_name)
                var_op = self.graph_item.graph.get_operation_by_name(var_op_name)
                var = self.graph_item.trainable_var_op_to_var[var_op]
                gradient = self.graph_item.var_op_name_to_grad_info[var_op_name][0]

                # Create partitioned variable and split gradients
                pc = PartitionerConfig(partition_str=partition_str)
                partition_config[var_op_name] = pc

                # Now check compatibility
                if isinstance(gradient, ops.IndexedSlices) and pc.axis != 0:
                    raise ValueError('Embedding variables can only be partitioned along the first axis due to '
                                     'the limitation on the `embedding_lookup_v2` op.')

                initial_value = new_graph_item.graph.get_tensor_by_name(var.initial_value.name)
                # NOTE: to enable the saver, for now we only support partition on the one dimension
                # https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/ops/variables.py#L2915
                partitioned_var = vs.get_variable(var_op.name, shape=None, initializer=initial_value,
                                                  partitioner=lambda pconf=pc, **unused_kwargs: pconf.partition_list,
                                                  validate_shape=False, use_resource=True)
                var_list = partitioned_var._variable_list

                # Distribute the shards of the partitioned variable if they have a PS synchronizer
                # (this may not be strictly necessary)
                for var_slice, device in zip(var_list, reduction_destinations):
                    if device:
                        var_slice.op._set_device_from_string(device)

                if isinstance(gradient, ops.IndexedSlices):
                    # Sparse variable
                    new_grad = ops.IndexedSlices(
                        indices=new_graph_item.graph.get_tensor_by_name(gradient.indices.name),
                        values=new_graph_item.graph.get_tensor_by_name(gradient.values.name),
                        dense_shape=new_graph_item.graph.get_tensor_by_name(gradient.dense_shape.name)
                    )
                    split_grad = self._split_indexed_slices_v2(new_grad, len(var_list), var.shape[0],
                                                               name=f"gradients/splits/sparse_split_{var_op_name}")
                else:
                    new_grad = new_graph_item.graph.get_tensor_by_name(gradient.name)

                    # new_grad can sometimes have a polymorphic (None) shape, so we use the shape of the original var
                    split_grad = self._split_tensor_v2(new_grad, pc.num_shards, var.shape, pc.axis,
                                                       name=f"gradients/splits/split_{var_op_name}")
                self._handle_read(new_graph_item, var_op, partitioned_var)
                self._update_node_config(var, var_list)

                self.info.update_variables(var_list, replace=False)
                new_vars.extend(var_list)
                new_grads.extend(split_grad)
                new_graph_item.extend_gradient_info(split_grad, var_list)
                new_graph_item.pop_gradient_info(var.name)
        new_graph_item.info = self.info.copy()
        all_vars, all_grads = new_vars, new_grads
        for var, grad in unpartitioned_vars.items():
            if isinstance(grad, ops.IndexedSlices):
                # Sparse variable
                grad = ops.IndexedSlices(
                    indices=new_graph_item.graph.get_tensor_by_name(grad.indices.name),
                    values=new_graph_item.graph.get_tensor_by_name(grad.values.name),
                    dense_shape=new_graph_item.graph.get_tensor_by_name(grad.dense_shape.name)
                )
            else:
                grad = new_graph_item.graph.get_tensor_by_name(grad.name)
            all_grads.append(grad)
            var = new_graph_item.trainable_var_op_to_var[new_graph_item.graph.get_operation_by_name(get_op_name(var))]
            # Clear `_distribute_strategy` so that AutoDist does not interfere with tf.distribute
            if (not hasattr(var, "_distribute_strategy")) or var._distribute_strategy:
                setattr(var, "_distribute_strategy", None)
            all_vars.append(var)
        with new_graph_item.graph.as_default():
            optimizer = self.graph_item.optimizer(*self.graph_item.optimizer_args[1:],
                                                  **self.graph_item.optimizer_kwargs)
            _ = optimizer.apply_gradients(zip(all_grads, all_vars))
        return new_vars, partition_config
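
Example #4 builds its shards through get_variable with a fixed partitioner, the same idiom sketched below (illustrative shape and names, assuming TF 1.x variable scopes; AutoDist's version passes the original initial-value tensor instead of an initializer and leaves the shape unvalidated):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

with tf.compat.v1.variable_scope("demo", use_resource=True):
    # The partitioner returns the number of shards per axis: 2 along axis 0, 1 along axis 1.
    partitioned = tf.compat.v1.get_variable(
        "embedding",
        shape=[8, 4],
        initializer=tf.compat.v1.zeros_initializer(),
        partitioner=lambda shape, dtype, **unused: [2, 1])

# A PartitionedVariable iterates over its underlying shard variables.
shards = list(partitioned)
print([v.shape.as_list() for v in shards])  # [[4, 4], [4, 4]]
print([v.op.name for v in shards])          # ['demo/embedding/part_0', 'demo/embedding/part_1']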