def grouping_gradients_apply(apply_func, grads_and_vars, *args, **kwargs):
    """NPU implemented gradients on grouping"""
    if global_npu_ctx() is None or not global_npu_ctx().is_cluster_worker():
        return apply_func(grads_and_vars, *args, **kwargs)

    grads_and_vars = tuple(grads_and_vars)  # grads_and_vars origin type is zip and can only be iter once

    op_list = []
    local_rank_id = global_npu_ctx().worker_id
    variables = []
    for _, var in grads_and_vars:
        variables.append(var)

    grouping_vars = GroupingVars(variables, global_npu_ctx().workers_num)

    # Apply only the gradients of variables assigned to this worker's group.
    local_grads_and_vars = []
    for grad, var in grads_and_vars:
        rank_id = grouping_vars.get_gid_by_var(var)
        if rank_id >= 0 and rank_id == local_rank_id:
            local_grads_and_vars.append((grad, var))
    apply_res = apply_func(local_grads_and_vars, *args, **kwargs)

    # Broadcast each updated variable from the worker that owns its group.
    with ops.get_default_graph()._attr_scope(
            {"_weight_update_grouping": attr_value_pb2.AttrValue(b=True)}):
        for var in variables:
            rank_id = grouping_vars.get_gid_by_var(var)
            hccl.broadcast([var], rank_id, 0)

    # Keep the gradients of remotely owned variables alive in the returned group.
    for grad, var in grads_and_vars:
        rank_id = grouping_vars.get_gid_by_var(var)
        if rank_id >= 0 and rank_id != local_rank_id:
            op_list.append(grad)
    op_list.append(apply_res)
    return tf.group(op_list)
def grouping_broadcast(variables):
    """Grouping broadcast on cluster"""
    if global_npu_ctx() is None or not global_npu_ctx().is_cluster_worker():
        logging.info("Skip grouping broadcast as current process is not npu cluster worker")
        return variables

    grouping_vars = GroupingVars(variables, global_npu_ctx().workers_num)
    for var in variables:
        rank_id = grouping_vars.get_gid_by_var(var)
        hccl.broadcast([var], rank_id, 0)
def broadcast(values, root_rank, fusion=2, fusion_id=0, group="hccl_world_group"):
    if global_npu_ctx() is None or not global_npu_ctx().is_cluster_worker():
        tf.get_logger().info("Skip broadcast value as current process is not npu cluster worker")
        return
    return tf.function(_broadcast)(values, root_rank, fusion, fusion_id, group)
def all_reduce(values, reduction, fusion=1, fusion_id=-1, group="hccl_world_group"):
    if global_npu_ctx() is None or not global_npu_ctx().is_cluster_worker():
        tf.get_logger().info("Skip all-reduce value as current process is not npu cluster worker")
        return values
    return tf.function(_all_reduce)(values, reduction, fusion, fusion_id, group)
def _npu_finite_status_after_executed(executed_ops):
    if not isinstance(executed_ops, (tuple, list)):
        executed_ops = [executed_ops]
    with ops.get_default_graph()._attr_scope(
            {"_npu_loss_scale": attr_value_pb2.AttrValue(b=True)}):
        with tf.control_dependencies([v for v in executed_ops if v is not None]):
            current_status = gen_npu_ops.npu_alloc_float_status()
        assign_float_status = gen_npu_ops.npu_get_float_status(current_status)
        finite_status = gen_npu_ops.npu_clear_float_status(assign_float_status)
        if global_npu_ctx() and global_npu_ctx().workers_num > 1:
            # Sum the float status across workers so any overflow anywhere is visible everywhere.
            with tf.control_dependencies([assign_float_status]):
                reduced_status = all_reduce(current_status, 'sum', fusion=0)
            return tf.reduce_all(tf.equal(reduced_status, finite_status))
        return tf.reduce_all(tf.equal(current_status, finite_status))
def apply_gradients(self, grads_and_vars, name=None):
    """Apply gradients on variables"""
    if global_npu_ctx() is None:
        return super().apply_gradients(grads_and_vars, name)

    grads_and_vars = tuple(grads_and_vars)  # grads_and_vars origin type is zip and can only be iter once
    grads = [g for g, _ in grads_and_vars]

    def apply_fn():
        wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars])
        return self._apply_gradients(grads, wrapped_vars, name)

    def do_not_apply_fn():
        return self._optimizer.iterations.assign_add(1, read_value=False)

    if self.dynamic:
        loss_scale_update_op, should_apply_grads = _npu_compat_loss_scale_update(self._loss_scale, grads)
    else:
        loss_scale_update_op = tf.no_op()
        should_apply_grads = _npu_finite_status_after_executed(grads)

    self._last_step_finite.assign(should_apply_grads)
    maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn, do_not_apply_fn)
    return tf.group(maybe_apply_op, loss_scale_update_op)
def _all_reduce(values, reduction, fusion, fusion_id, group):
    workers_num = global_npu_ctx().workers_num

    # 'mean' is implemented as a 'sum' all-reduce followed by a local scaling step.
    mean_reduce = False
    if reduction == 'mean':
        mean_reduce = True
        reduction = 'sum'

    reduced_values = []
    for value in values:
        reduced_value = hccl_ops.allreduce(value, reduction, fusion, fusion_id, group)
        is_float = reduced_value.dtype in (tf.float16, tf.float32, tf.float64)
        if is_float:
            # Pre-compute the reciprocal so the mean becomes a multiply for float dtypes.
            typed_workers_num = tf.cast(1.0 / float(workers_num), reduced_value.dtype)
        else:
            typed_workers_num = tf.cast(workers_num, reduced_value.dtype)
        with tf.control_dependencies([tf.group(*values)]):
            if mean_reduce:
                if is_float:
                    reduced_values.append(tf.multiply(reduced_value, typed_workers_num))
                else:
                    reduced_values.append(tf.divide(reduced_value, typed_workers_num))
            else:
                # Multiply by one so every output is a fresh tensor behind the control dependency.
                reduced_values.append(tf.multiply(reduced_value, tf.cast(1, reduced_value.dtype)))
    return reduced_values
def _all_reduce(values, reduction, fusion, fusion_id, group):
    workers_num = global_npu_ctx().workers_num

    mean_reduce = False
    if reduction == 'mean':
        mean_reduce = True
        reduction = 'sum'

    topo_guarder = tf.group(values)
    if isinstance(values, (list, tuple,)):
        reduced_values = []
        for value in values:
            reduced_value = hccl_ops.allreduce(value, reduction, fusion, fusion_id, group)
            typed_workers_num = tf.cast(workers_num, reduced_value.dtype)
            with tf.control_dependencies([topo_guarder]):
                if mean_reduce:
                    reduced_values.append(tf.divide(reduced_value, typed_workers_num))
                else:
                    reduced_values.append(tf.identity(reduced_value))
        return reduced_values
    else:
        reduced_value = hccl_ops.allreduce(values, reduction, fusion, fusion_id, group)
        typed_workers_num = tf.cast(workers_num, reduced_value.dtype)
        with tf.control_dependencies([topo_guarder]):
            if mean_reduce:
                return tf.divide(reduced_value, typed_workers_num)
            else:
                return tf.identity(reduced_value)
def broadcast(values, root_rank=0, fusion=2, fusion_id=0, group="hccl_world_group"):
    """Broadcast value among cluster"""
    if global_npu_ctx() is None or not global_npu_ctx().is_cluster_worker():
        logging.info("Skip broadcast as current process is not npu cluster worker")
        return

    if isinstance(values, (list, tuple,)):
        _broadcast(values, root_rank, fusion, fusion_id, group)
    else:
        _broadcast([values], root_rank, fusion, fusion_id, group)
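A typical use of broadcast is to push the root rank's initial variable values to every worker after model construction, so all workers start from identical weights. The sketch below is illustrative only; the import path is an assumption, not the library's documented API.

# Sketch only: the import path is hypothetical; broadcast(values, root_rank=0, ...) is the
# variant defined above, which accepts a single variable or a list of variables and is a
# no-op outside an NPU cluster.
import tensorflow as tf
from npu_device.distribute import broadcast  # hypothetical import path

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.build(input_shape=(None, 32))

# Overwrite every worker's variables with rank 0's values.
broadcast(model.trainable_variables, root_rank=0)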
def shard_and_rebatch_dataset(dataset, global_bs):
    if global_npu_ctx() is None or global_npu_ctx().workers_num <= 1:
        return dataset, global_bs
    if global_bs % global_npu_ctx().workers_num != 0:
        raise ValueError('Batch size must be divisible by num npus: {}'.format(
            global_npu_ctx().workers_num))

    batch_size = int(global_bs) / global_npu_ctx().workers_num
    dataset = dataset.shard(global_npu_ctx().workers_num, global_npu_ctx().worker_id)

    return dataset, int(batch_size)
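A minimal usage sketch for shard_and_rebatch_dataset, assuming it is importable as shown (the import path is an assumption): each worker keeps its shard of the dataset and batches with the derived per-worker batch size.

# Illustrative only: the import path below is an assumption, not the library's documented API.
import tensorflow as tf
from npu_device.distribute import shard_and_rebatch_dataset  # hypothetical import path

GLOBAL_BATCH_SIZE = 256
dataset = tf.data.Dataset.range(1024).map(lambda x: (x, x * 2))

# Returns the dataset and batch size unchanged when not running as an NPU cluster worker;
# otherwise shards by worker_id and divides the global batch size by workers_num.
dataset, local_batch_size = shard_and_rebatch_dataset(dataset, GLOBAL_BATCH_SIZE)
dataset = dataset.batch(local_batch_size, drop_remainder=True)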
def all_reduce(values, reduction="mean", fusion=1, fusion_id=-1, group="hccl_world_group"): """NPU implemented all_reduce""" if global_npu_ctx() is None or not global_npu_ctx().is_cluster_worker(): logging.info( "Skip all reduce as current process is not npu cluster worker") return values if isinstance(values, ( list, tuple, )): is_list_value = True else: is_list_value = False values = [values] reduced_values = _all_reduce([v for v in values if v is not None], reduction, fusion, fusion_id, group) results = [None if v is None else reduced_values.pop(0) for v in values] return results if is_list_value else results[0]