Exemplo n.º 1
0
    def _get_aggregated_sparse_grad(self, graph_item, var_op, grad,
                                    reduce_to_device, BFTaggregator):
        """Gather every replica's copy of a sparse gradient and aggregate them.

        ``grad`` is used as a template: the op names behind its ``indices``,
        ``values`` and ``dense_shape`` tensors are stripped of their replica
        prefix and then re-scoped with ``replica_prefix(i)`` for each
        ``i in range(self.num_replicas)`` to find that replica's tensors.

        Args:
            graph_item: Object whose ``graph`` contains all replica ops.
            var_op: Forwarded unchanged to ``self._aggregate_sparse_gradients``.
            grad: Sparse gradient exposing ``indices``, ``values`` and
                ``dense_shape`` tensors.
            reduce_to_device: Forwarded unchanged to
                ``self._aggregate_sparse_gradients``.
            BFTaggregator: Not used by this method.
                NOTE(review): sparse gradients bypass the BFT aggregator here;
                confirm this is intended.

        Returns:
            The result of ``self._aggregate_sparse_gradients`` applied to the
            per-replica ``ops.IndexedSlices`` list.
        """
        indices_op_name = strip_replica_prefix(get_op_name(grad.indices.name))
        values_op_name = strip_replica_prefix(get_op_name(grad.values.name))
        dense_shape_op_name = strip_replica_prefix(
            get_op_name(grad.dense_shape.name))

        # The output slot of each component is the same for every replica,
        # so resolve it once instead of on every loop iteration.
        values_idx = utils.get_index_from_tensor_name(grad.values.name)
        indices_idx = utils.get_index_from_tensor_name(grad.indices.name)
        dense_shape_idx = utils.get_index_from_tensor_name(
            grad.dense_shape.name)

        graph = graph_item.graph
        indexed_slices_grads = []
        for i in range(self.num_replicas):
            prefix = replica_prefix(i)
            indices_op = graph.get_operation_by_name(
                ops.prepend_name_scope(indices_op_name, prefix))
            values_op = graph.get_operation_by_name(
                ops.prepend_name_scope(values_op_name, prefix))
            dense_shape_op = graph.get_operation_by_name(
                ops.prepend_name_scope(dense_shape_op_name, prefix))
            indexed_slices_grads.append(
                ops.IndexedSlices(
                    values_op.outputs[values_idx],
                    indices_op.outputs[indices_idx],
                    dense_shape_op.outputs[dense_shape_idx]))

        return self._aggregate_sparse_gradients(var_op, reduce_to_device,
                                                indexed_slices_grads,
                                                values_op_name)
Exemplo n.º 2
0
    def _get_aggregated_dense_grad(self, graph_item, grad_name,
                                   reduce_to_device, BFTaggregator):
        """Aggregate a dense gradient across all replicas via ``BFTaggregator``.

        The op behind ``grad_name`` is stripped of its replica prefix and then
        re-scoped with ``replica_prefix(i)`` for each
        ``i in range(self.num_replicas)``. The same output slot of each
        replica's op is collected.

        Args:
            graph_item: Object whose ``graph`` contains all replica ops.
            grad_name: Tensor name (``op_name:index``) of the gradient.
            reduce_to_device: Device under which the aggregation ops are built.
            BFTaggregator: Object whose ``aggregate(list_of_tensors)`` performs
                the reduction.

        Returns:
            Whatever ``BFTaggregator.aggregate`` returns for the per-replica
            gradient tensors.
        """
        grad_op_name = strip_replica_prefix(get_op_name(grad_name))
        output_idx = get_index_from_tensor_name(grad_name)
        grad_ops = [
            graph_item.graph.get_operation_by_name(
                ops.prepend_name_scope(grad_op_name, replica_prefix(i)))
            for i in range(self.num_replicas)
        ]

        # Build the aggregation ops on `reduce_to_device` (usually CPU).
        with ops.device(reduce_to_device):
            gradients = [grad_op.outputs[output_idx] for grad_op in grad_ops]
            # The aggregation rule is delegated to the BFT aggregator rather
            # than a plain add_n + divide-by-num_replicas mean.
            aggregated = BFTaggregator.aggregate(gradients)

        return aggregated
Exemplo n.º 3
0
    def _update_gradient_consumers(new_graph_item, consumer_ops,
                                   control_consumer_ops, old_tensor_name,
                                   new_tensor):
        """Rewire consumers of replica 0's gradient onto the aggregated tensor.

        Args:
            new_graph_item: Object whose ``graph`` holds the replica ops.
            consumer_ops: Ops whose data inputs should switch from replica 0's
                gradient tensor to ``new_tensor``.
            control_consumer_ops: Ops whose control inputs should switch from
                replica 0's gradient op to ``new_tensor.op``.
            old_tensor_name: Tensor name (``op_name:index``) of the original
                gradient; any replica prefix is stripped before re-scoping.
            new_tensor: The aggregated gradient tensor to consume instead.
        """
        # Locate the replica-0 instance of the original gradient tensor.
        base_op_name = strip_replica_prefix(get_op_name(old_tensor_name))
        scoped_name = ops.prepend_name_scope(base_op_name, replica_prefix(0))
        source_op = new_graph_item.graph.get_operation_by_name(scoped_name)
        slot = get_index_from_tensor_name(old_tensor_name)
        stale_tensor = source_op.outputs[slot]

        # Swap both data edges and control edges over to the new gradient.
        update_consumers(consumer_ops, stale_tensor, new_tensor)
        update_control_consumers(
            control_consumer_ops, stale_tensor.op, new_tensor.op)
Exemplo n.º 4
0
def test_strip_replica_prefix():
    """Prepending a replica name scope and stripping it returns the input.

    Covers a plain op name, a control-input name ('^' prefix) and a tensor
    name (':0' suffix).
    """
    prefix = replica_prefix(12)
    for original in ('my_op', '^my_op', 'my_tensor:0'):
        scoped = ops.prepend_name_scope(original, prefix)
        assert strip_replica_prefix(scoped) == original