def grouping_gradients_apply(apply_func, grads_and_vars, *args, **kwargs):
    """NPU implementation of grouped gradient application"""
    if global_npu_ctx() is None or not global_npu_ctx().is_cluster_worker():
        return apply_func(grads_and_vars, *args, **kwargs)

    grads_and_vars = tuple(grads_and_vars)  # may be a zip object, which can only be iterated once

    op_list = []

    local_rank_id = global_npu_ctx().worker_id
    variables = [var for _, var in grads_and_vars]
    # GroupingVars assigns each variable to one worker; a worker only applies
    # the gradients of the variables assigned to it.
    grouping_vars = GroupingVars(variables, global_npu_ctx().workers_num)
    local_grads_and_vars = []
    for grad, var in grads_and_vars:
        rank_id = grouping_vars.get_gid_by_var(var)
        if rank_id >= 0 and rank_id == local_rank_id:
            local_grads_and_vars.append((grad, var))
    apply_res = apply_func(local_grads_and_vars, *args, **kwargs)
    with ops.get_default_graph()._attr_scope(
            {"_weight_update_grouping": attr_value_pb2.AttrValue(b=True)}):
        # Broadcast each variable from the worker that owns (and just updated) it.
        for var in variables:
            rank_id = grouping_vars.get_gid_by_var(var)
            hccl.broadcast([var], rank_id, 0)
    # Keep the gradient ops of non-local variables in the returned group so the
    # graph still computes them on every worker.
    for grad, var in grads_and_vars:
        rank_id = grouping_vars.get_gid_by_var(var)
        if rank_id >= 0 and rank_id != local_rank_id:
            op_list.append(grad)
    op_list.append(apply_res)
    return tf.group(op_list)

def grouping_broadcast(variables):
    """Grouped broadcast of variables across the cluster"""
    if global_npu_ctx() is None or not global_npu_ctx().is_cluster_worker():
        logging.info("Skip grouping broadcast as current process is not an NPU cluster worker")
        return variables
    grouping_vars = GroupingVars(variables, global_npu_ctx().workers_num)
    # Broadcast each variable from the worker it is assigned to.
    for var in variables:
        rank_id = grouping_vars.get_gid_by_var(var)
        hccl.broadcast([var], rank_id, 0)
    return variables
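
# Usage sketch (assumption: grouping_gradients_apply and grouping_broadcast
# above are importable and an NPU cluster context is active; the tiny model
# below is purely illustrative).
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model.build(input_shape=(None, 8))
optimizer = tf.keras.optimizers.SGD(0.01)

# Start every worker from the same weights.
grouping_broadcast(model.trainable_variables)

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(model(tf.ones([2, 8])))
grads = tape.gradient(loss, model.trainable_variables)
# Each worker applies only the variables assigned to it; the helper then
# broadcasts the updated values back to the other workers.
train_op = grouping_gradients_apply(optimizer.apply_gradients,
                                    zip(grads, model.trainable_variables))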
Example #3
def broadcast(values,
              root_rank,
              fusion=2,
              fusion_id=0,
              group="hccl_world_group"):
    """Broadcast values from root_rank to every worker in the cluster"""
    if global_npu_ctx() is None or not global_npu_ctx().is_cluster_worker():
        tf.get_logger().info(
            "Skip broadcast value as current process is not an NPU cluster worker"
        )
        return

    # Run the HCCL broadcast as a tf.function so it executes as graph ops.
    tf.function(_broadcast)(values, root_rank, fusion, fusion_id, group)
Example #4
def all_reduce(values,
               reduction,
               fusion=1,
               fusion_id=-1,
               group="hccl_world_group"):
    """All-reduce values across the cluster with the given reduction"""
    if global_npu_ctx() is None or not global_npu_ctx().is_cluster_worker():
        tf.get_logger().info(
            "Skip all-reduce value as current process is not an NPU cluster worker"
        )
        return values

    return tf.function(_all_reduce)(values, reduction, fusion, fusion_id,
                                    group)
Example #5
def _npu_finite_status_after_executed(executed_ops):
    """Return a bool tensor that is True if no inf/nan occurred in executed_ops"""
    if not isinstance(executed_ops, (tuple, list)):
        executed_ops = [executed_ops]
    with ops.get_default_graph()._attr_scope(
        {"_npu_loss_scale": attr_value_pb2.AttrValue(b=True)}):
        # Read the NPU float-status register after the given ops have run, then
        # clear it; a mismatch between the read value and the cleared value
        # means an overflow (inf/nan) happened.
        with tf.control_dependencies(
            [v for v in executed_ops if v is not None]):
            current_status = gen_npu_ops.npu_alloc_float_status()
        assign_float_status = gen_npu_ops.npu_get_float_status(current_status)
        finite_status = gen_npu_ops.npu_clear_float_status(assign_float_status)
        if global_npu_ctx() and global_npu_ctx().workers_num > 1:
            # All-reduce the status so every worker reaches the same decision.
            with tf.control_dependencies([assign_float_status]):
                reduced_status = all_reduce(current_status, 'sum', fusion=0)
            return tf.reduce_all(tf.equal(reduced_status, finite_status))
        return tf.reduce_all(tf.equal(current_status, finite_status))
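
# Usage sketch (assumption: _npu_finite_status_after_executed above is in scope
# and the process runs on an NPU device that provides the float-status ops; the
# tensors below are illustrative). The helper returns a scalar bool that is True
# only when none of the executed ops produced inf/nan, so a caller can gate the
# weight update on it.
import tensorflow as tf

grads = [tf.constant([1.0, 2.0]), tf.constant([3.0])]
update_is_safe = _npu_finite_status_after_executed(grads)
maybe_update = tf.cond(update_is_safe,
                       lambda: tf.constant(1.0),   # stand-in for the real apply op
                       lambda: tf.constant(0.0))   # stand-in for skipping the step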
Example #6
    def apply_gradients(self, grads_and_vars, name=None):
        """Apply gradients on variables"""
        if global_npu_ctx() is None:
            return super().apply_gradients(grads_and_vars, name)

        grads_and_vars = tuple(grads_and_vars)  # may be a zip object, which can only be iterated once
        grads = [g for g, _ in grads_and_vars]

        def apply_fn():
            wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars])
            return self._apply_gradients(grads, wrapped_vars, name)

        def do_not_apply_fn():
            return self._optimizer.iterations.assign_add(1, read_value=False)

        if self.dynamic:
            # Dynamic loss scaling: update the scale and decide whether the
            # gradients are safe to apply.
            loss_scale_update_op, should_apply_grads = _npu_compat_loss_scale_update(
                self._loss_scale, grads)
        else:
            # Fixed loss scale: just check that all gradients are finite.
            loss_scale_update_op = tf.no_op()
            should_apply_grads = _npu_finite_status_after_executed(grads)

        self._last_step_finite.assign(should_apply_grads)
        maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                               do_not_apply_fn)
        return tf.group(maybe_apply_op, loss_scale_update_op)
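
# Usage sketch (assumption: the apply_gradients method above belongs to an NPU
# loss-scale optimizer wrapper that follows the Keras LossScaleOptimizer
# interface; the wrapper name NpuLossScaleOptimizer, the model and the data are
# hypothetical and only show the call pattern).
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.build(input_shape=(None, 4))
opt = NpuLossScaleOptimizer(tf.keras.optimizers.SGD(0.01), dynamic=True)  # hypothetical name

with tf.GradientTape() as tape:
    loss = tf.reduce_sum(model(tf.ones([2, 4])))
    scaled_loss = opt.get_scaled_loss(loss)          # assumes the Keras-style API
scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
grads = opt.get_unscaled_gradients(scaled_grads)     # assumes the Keras-style API
# The overridden apply_gradients skips the update when an overflow is detected.
opt.apply_gradients(zip(grads, model.trainable_variables))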
Example #7
def _all_reduce(values, reduction, fusion, fusion_id, group):
    workers_num = global_npu_ctx().workers_num

    # A 'mean' reduction is implemented as a 'sum' all-reduce followed by a
    # division by the number of workers.
    mean_reduce = False
    if reduction == 'mean':
        mean_reduce = True
        reduction = 'sum'

    reduced_values = []
    for value in values:
        reduced_value = hccl_ops.allreduce(value, reduction, fusion, fusion_id,
                                           group)
        is_float = reduced_value.dtype in (tf.float16, tf.float32, tf.float64)
        if is_float:
            # For float tensors the division is done as a multiply by a
            # host-side precomputed 1/N.
            typed_workers_num = tf.cast(1.0 / float(workers_num),
                                        reduced_value.dtype)
        else:
            typed_workers_num = tf.cast(workers_num, reduced_value.dtype)
        # Depend on all inputs so none of the all-reduced values is pruned.
        with tf.control_dependencies([tf.group(*values)]):
            if mean_reduce:
                if is_float:
                    reduced_values.append(
                        tf.multiply(reduced_value, typed_workers_num))
                else:
                    reduced_values.append(
                        tf.divide(reduced_value, typed_workers_num))
            else:
                # Multiply by one: an identity that carries the control dependency.
                reduced_values.append(
                    tf.multiply(reduced_value, tf.cast(1,
                                                       reduced_value.dtype)))
    return reduced_values
Example #8
def _all_reduce(values, reduction, fusion, fusion_id, group):
    workers_num = global_npu_ctx().workers_num

    # A 'mean' reduction is implemented as a 'sum' all-reduce followed by a
    # division by the number of workers.
    mean_reduce = False
    if reduction == 'mean':
        mean_reduce = True
        reduction = 'sum'

    # Depend on all inputs so none of them is pruned from the graph.
    topo_guarder = tf.group(values)
    if isinstance(values, (list, tuple)):
        reduced_values = []
        for value in values:
            reduced_value = hccl_ops.allreduce(value, reduction, fusion,
                                               fusion_id, group)
            typed_workers_num = tf.cast(workers_num, reduced_value.dtype)
            with tf.control_dependencies([topo_guarder]):
                if mean_reduce:
                    reduced_values.append(
                        tf.divide(reduced_value, typed_workers_num))
                else:
                    reduced_values.append(tf.identity(reduced_value))
        return reduced_values
    else:
        reduced_value = hccl_ops.allreduce(values, reduction, fusion,
                                           fusion_id, group)
        typed_workers_num = tf.cast(workers_num, reduced_value.dtype)
        with tf.control_dependencies([topo_guarder]):
            if mean_reduce:
                return tf.divide(reduced_value, typed_workers_num)
            else:
                return tf.identity(reduced_value)
Example #9
def broadcast(values,
              root_rank=0,
              fusion=2,
              fusion_id=0,
              group="hccl_world_group"):
    """Broadcast values among the cluster"""
    if global_npu_ctx() is None or not global_npu_ctx().is_cluster_worker():
        logging.info(
            "Skip broadcast as current process is not an NPU cluster worker")
        return
    if isinstance(values, (list, tuple)):
        _broadcast(values, root_rank, fusion, fusion_id, group)
    else:
        _broadcast([values], root_rank, fusion, fusion_id, group)
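
# Usage sketch (assumption: broadcast above is importable and an NPU cluster
# context is active; the model is illustrative).
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.build(input_shape=(None, 32))
# Make every worker start from rank 0's copy of the weights before training.
broadcast(model.trainable_variables, root_rank=0)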
Example #10
def shard_and_rebatch_dataset(dataset, global_bs):
    """Shard the dataset across NPU workers and return the per-worker batch size"""
    if global_npu_ctx() is None or global_npu_ctx().workers_num <= 1:
        return dataset, global_bs
    if global_bs % global_npu_ctx().workers_num != 0:
        raise ValueError(
            'Global batch size {} must be divisible by the number of NPU workers {}'.format(
                global_bs, global_npu_ctx().workers_num))

    batch_size = int(global_bs) // global_npu_ctx().workers_num
    dataset = dataset.shard(global_npu_ctx().workers_num,
                            global_npu_ctx().worker_id)

    return dataset, batch_size
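
# Usage sketch (assumption: shard_and_rebatch_dataset above is importable; the
# synthetic dataset only illustrates the call order). The helper shards the
# dataset by worker id and returns the per-worker batch size, which the caller
# then uses to batch.
import tensorflow as tf

global_batch_size = 128
dataset = tf.data.Dataset.from_tensor_slices(tf.range(1024))
dataset, per_worker_bs = shard_and_rebatch_dataset(dataset, global_batch_size)
dataset = dataset.batch(per_worker_bs, drop_remainder=True)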
Example #11
def all_reduce(values,
               reduction="mean",
               fusion=1,
               fusion_id=-1,
               group="hccl_world_group"):
    """NPU implementation of all_reduce; None entries are passed through unchanged"""
    if global_npu_ctx() is None or not global_npu_ctx().is_cluster_worker():
        logging.info(
            "Skip all reduce as current process is not an NPU cluster worker")
        return values

    is_list_value = isinstance(values, (list, tuple))
    if not is_list_value:
        values = [values]
    # Reduce only the non-None entries, then put None back at its original positions.
    reduced_values = _all_reduce([v for v in values if v is not None],
                                 reduction, fusion, fusion_id, group)
    results = [None if v is None else reduced_values.pop(0) for v in values]
    return results if is_list_value else results[0]
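
# Usage sketch (assumption: all_reduce above is importable and an NPU cluster
# context is active; model, optimizer and loss_fn are illustrative). Gradients
# are averaged across workers before the optimizer applies them; on a single
# worker the call returns the inputs unchanged, and None gradients are passed
# through untouched, matching the helper's behaviour.
import tensorflow as tf

def train_step(model, optimizer, loss_fn, features, labels):
    with tf.GradientTape() as tape:
        loss = loss_fn(labels, model(features, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    grads = all_reduce(grads, reduction="mean")
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss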