def _get_aggregated_sparse_grad(self, graph_item, var_op, grad, reduce_to_device, BFTaggregator):
    """Collect the per-replica IndexedSlices of `grad` and aggregate them."""
    indices_op_name = strip_replica_prefix(get_op_name(grad.indices.name))
    values_op_name = strip_replica_prefix(get_op_name(grad.values.name))
    dense_shape_op_name = strip_replica_prefix(get_op_name(grad.dense_shape.name))

    indexed_slices_grads = []
    for i in range(self.num_replicas):
        indices_op = graph_item.graph.get_operation_by_name(
            ops.prepend_name_scope(indices_op_name, replica_prefix(i)))
        values_op = graph_item.graph.get_operation_by_name(
            ops.prepend_name_scope(values_op_name, replica_prefix(i)))
        dense_shape_op = graph_item.graph.get_operation_by_name(
            ops.prepend_name_scope(dense_shape_op_name, replica_prefix(i)))
        indexed_slices_grads.append(
            ops.IndexedSlices(
                values_op.outputs[utils.get_index_from_tensor_name(grad.values.name)],
                indices_op.outputs[utils.get_index_from_tensor_name(grad.indices.name)],
                dense_shape_op.outputs[utils.get_index_from_tensor_name(grad.dense_shape.name)]))

    return self._aggregate_sparse_gradients(var_op, reduce_to_device, indexed_slices_grads, values_op_name)
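# NOTE: `_aggregate_sparse_gradients` is not shown in this excerpt. The helper
# below is a hypothetical, standalone sketch (not AutoDist's actual method) of
# how per-replica IndexedSlices are commonly combined: concatenate indices and
# values across replicas and rescale the values so the implicit sum over
# duplicate indices matches the dense-gradient averaging.
import tensorflow as tf


def _aggregate_indexed_slices_sketch(indexed_slices_grads, num_replicas):
    # Concatenate per-replica rows; optimizers sum duplicate indices when
    # applying the sparse update, so scaling values by 1/num_replicas turns
    # that sum into an average.
    indices = tf.concat([g.indices for g in indexed_slices_grads], axis=0)
    values = tf.concat([g.values for g in indexed_slices_grads], axis=0)
    return tf.IndexedSlices(values / num_replicas, indices,
                            indexed_slices_grads[0].dense_shape)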
def _get_aggregated_dense_grad(self, graph_item, grad_name, reduce_to_device, BFTaggregator):
    """Collect the per-replica dense gradients of `grad_name` and aggregate them."""
    grad_op_name = strip_replica_prefix(get_op_name(grad_name))
    output_idx = get_index_from_tensor_name(grad_name)
    grad_ops = [
        graph_item.graph.get_operation_by_name(
            ops.prepend_name_scope(grad_op_name, replica_prefix(i)))
        for i in range(self.num_replicas)
    ]

    # Aggregate gradients on `reduce_to_device` (usually CPU).
    with ops.device(reduce_to_device):
        # The original AutoDist path averaged the replica gradients with
        # `math_ops.add_n` followed by `math_ops.realdiv` (ops named under
        # AUTODIST_PREFIX); here the replica gradients are handed to the
        # Byzantine-fault-tolerant aggregator instead.
        gradients = [grad_op.outputs[output_idx] for grad_op in grad_ops]
        grad_avg = BFTaggregator.aggregate(gradients)
    return grad_avg
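# The code above only assumes that the aggregator exposes
# `aggregate(list_of_tensors) -> tensor`. As one illustration of such a
# Byzantine-fault-tolerant rule, the hypothetical class below (not part of the
# original code) takes the coordinate-wise median of the replica gradients,
# which tolerates a minority of arbitrarily corrupted replicas.
import tensorflow as tf


class CoordinateWiseMedianAggregator:
    """Hypothetical BFT aggregator: coordinate-wise median across replicas."""

    def aggregate(self, gradients):
        n = len(gradients)
        # Stack to shape [num_replicas, ...] and sort each coordinate across
        # the replica axis; the middle element(s) give the median.
        stacked = tf.stack(gradients, axis=0)
        sorted_grads = tf.sort(stacked, axis=0)
        if n % 2 == 1:
            return sorted_grads[n // 2]
        return (sorted_grads[n // 2 - 1] + sorted_grads[n // 2]) / 2.0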
def _update_gradient_consumers(new_graph_item, consumer_ops, control_consumer_ops,
                               old_tensor_name, new_tensor):
    """Make the gradient's consumers consume the aggregated gradient instead of the original one from replica 0."""
    # Get the original tensor (the one from replica 0) to replace.
    old_op_name = strip_replica_prefix(get_op_name(old_tensor_name))
    replica_0_op_name = ops.prepend_name_scope(old_op_name, replica_prefix(0))
    replica_0_op = new_graph_item.graph.get_operation_by_name(replica_0_op_name)
    output_idx = get_index_from_tensor_name(old_tensor_name)
    replica_0_tensor = replica_0_op.outputs[output_idx]

    update_consumers(consumer_ops, replica_0_tensor, new_tensor)
    update_control_consumers(control_consumer_ops, replica_0_tensor.op, new_tensor.op)
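# `update_consumers` and `update_control_consumers` come from AutoDist's
# graph-transformation utilities and are not shown in this excerpt. The
# functions below are a minimal sketch of what such rewiring is assumed to do,
# using TF1-style graph surgery via private Operation methods; they are
# illustrative, not the actual helpers.
def _update_consumers_sketch(consumer_ops, old_tensor, new_tensor):
    # Point every data input that currently reads `old_tensor` at `new_tensor`.
    for consumer_op in consumer_ops:
        for i, input_tensor in enumerate(consumer_op.inputs):
            if input_tensor is old_tensor:
                consumer_op._update_input(i, new_tensor)


def _update_control_consumers_sketch(control_consumer_ops, old_op, new_op):
    # Replace a control dependency on `old_op` with one on `new_op`.
    for consumer_op in control_consumer_ops:
        kept = [op for op in consumer_op.control_inputs if op is not old_op]
        consumer_op._remove_all_control_inputs()
        consumer_op._add_control_inputs(kept + [new_op])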
def test_strip_replica_prefix():
    for name in ['my_op', '^my_op', 'my_tensor:0']:
        new_name = ops.prepend_name_scope(name, replica_prefix(12))
        assert strip_replica_prefix(new_name) == name
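# The three name forms in the test matter because TensorFlow's
# `ops.prepend_name_scope` keeps a leading control-input caret ('^') and a
# trailing tensor output index (':0') outside the inserted scope, so
# `strip_replica_prefix` must undo the scope without disturbing them. A rough
# illustration (the literal prefix string 'replica-12' is an assumption; the
# real `replica_prefix(12)` may format it differently):
def _illustrate_prepend_name_scope():
    print(ops.prepend_name_scope('my_op', 'replica-12'))        # replica-12/my_op
    print(ops.prepend_name_scope('^my_op', 'replica-12'))       # ^replica-12/my_op
    print(ops.prepend_name_scope('my_tensor:0', 'replica-12'))  # replica-12/my_tensor:0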