示例#1
0
def apply_power_sign(var,
                     m,
                     grad,
                     lr,
                     logbase,
                     sign_decay,
                     beta,
                     target=utils.CCE):
    """
    Update 'var' according to the PowerSign update rule.

    m_out = beta * m + (1 - beta) * grad
    var_out = var - lr_t * (exp(logbase * sign_decay * Sign(grad) * Sign(m_out)) * grad)

    Args:
        var (tvm.tensor.Tensor): A tensor of type float16 or float32.
        m (tvm.tensor.Tensor): A tensor of same shape and type as var.
        grad (tvm.tensor.Tensor): A tensor of same shape and type as var.
        lr (tvm.tensor.Tensor): A scalar tensor of same type as var.
        logbase (tvm.tensor.Tensor): A scalar tensor of same type as var.
        sign_decay (tvm.tensor.Tensor): A scalar tensor of same type as var.
        beta (tvm.tensor.Tensor): A scalar tensor of same type as var.

    Returns:
        tvm.tensor.Tensor, updated var.
        tvm.tensor.Tensor, updated m.
        dict, buffer-bind attributes for the inplace updates.
    """
    # dtype checks: var must be float16/float32 and every input shares its dtype
    utils.ops_dtype_check(var.dtype, utils.DtypeForDavinci.ALL_FLOAT)
    for tensor in (m, grad, lr, logbase, sign_decay, beta):
        utils.elemwise_dtype_check(var.dtype, tensor.dtype)

    # shape checks: m and grad match var elementwise; the rest are scalars
    for tensor in (m, grad):
        utils.elemwise_shape_check(var.shape, tensor.shape)
    for scalar in (lr, logbase, sign_decay, beta):
        if tuple(get_shape(scalar)) != (1, ):
            raise RuntimeError(
                "lr, logbase, sign_decay and beta only support scalar tensor.")

    # compute the PowerSign update
    out_var, out_m = _apply_power_sign_compute(var, m, grad, lr, logbase,
                                               sign_decay, beta)

    # bind the outputs back onto the input buffers so var and m are reused
    out_var, binds_info = TensorUtils.inplace_set(var, out_var, "var_buf")
    out_m, binds_info2 = TensorUtils.inplace_set(m, out_m, "m_buf")
    binds_info.update(binds_info2)
    return out_var, out_m, {utils.BINDS: binds_info}
def inplace_operate_bind(in_tensors, out_tensors, inplace_binds):
    """
    Bind selected output tensors so they are computed inplace on input tensors.

    Args:
        in_tensors (Union[list, tuple]): Origin input tensors.
        out_tensors (Union[list, tuple]): Origin output tensors.
        inplace_binds (tuple): A tuple of (in_id, out_id) tuples, meaning the
                               out_id-th output tensor is updated inplace
                               into the in_id-th input tensor.

    Returns:
        Two elements tuple, one for output tensors, the other for tensor bind relations.

    Raises:
        RuntimeError: If any bind index is out of range.
    """
    num_in = len(in_tensors)
    num_out = len(out_tensors)
    for in_id, out_id in inplace_binds:
        if in_id >= num_in or out_id >= num_out:
            raise RuntimeError("Inplace binds is invalid, while there are {} "
                               "input tensors and {} output tensors, but get "
                               "bind {}.".format(num_in, num_out,
                                                 inplace_binds))

    new_outputs = list(out_tensors)
    tensor_binds = {}
    inplaced = []

    for idx, (in_id, out_id) in enumerate(inplace_binds):
        bound_out, binds_info = TensorUtils.inplace_set(
            in_tensors[in_id], new_outputs[out_id],
            buffer_name="inp_buf_{}".format(idx))
        tensor_binds.update(binds_info)

        # The calculation is updated inplace in the input tensor, but
        # Mindspore still expects a related (never used) tensor in the output
        # list, so stand a fake tensor in for the bound output.
        fake_tensor = akg.tvm.compute(
            bound_out.shape,
            lambda *index, o_tensor=bound_out: o_tensor(*index),
            name="fake_tensor_{}".format(idx))

        new_outputs[out_id] = fake_tensor
        inplaced.append(bound_out)

    return (tuple(new_outputs + inplaced), tensor_binds)
示例#3
0
def assign_add(data, value):
    """
    Computes data + value elementwise and writes the result back into data.

    Note:
        Only supports broadcast on input tensor value.

    Args:
        data (tvm.tensor.Tensor): Data tensor.
        value (tvm.tensor.Tensor): Value tensor, broadcast is allowed.

    Returns:
        res: assign add result, tvm.tensor.Tensor, with same type and shape as input tensor data.
        attrs: dict, buffer-bind attributes (res reuses data's buffer).

    Raises:
        RuntimeError: If the shapes require broadcasting data itself, or
            value cannot be broadcast to data's shape.
    """
    # NOTE(review): assumes static shapes (each dim has a .value) — confirm
    # callers never pass symbolic shapes.
    input_shape = [x.value for x in data.shape]
    value_shape = [x.value for x in value.shape]

    # broadcasting data itself is not supported: data's rank must cover value's
    if len(input_shape) < len(value_shape):
        raise RuntimeError("Do not support broadcast on input tensor data!")

    # aligned from the trailing dims, every dim of value must fit within data
    for i in range(len(value_shape)):
        if input_shape[len(input_shape) - i -
                       1] < value_shape[len(value_shape) - i - 1]:
            raise RuntimeError("Only support broadcast on input tensor value!")

    # broadcast adds extra compute and stage, avoid by checking the shapes
    # beforehand (unequal ranks already make the lists unequal, so a single
    # inequality test covers both cases)
    if value_shape != input_shape:
        broadcasted_value = akg.topi.broadcast_to(value, input_shape)
        res = akg.lang.cce.vadd(data, broadcasted_value)
    else:
        res = akg.lang.cce.vadd(data, value)
    res, binds_info = TensorUtils.inplace_set(data, res)
    attrs = {utils.BINDS: binds_info}
    return res, attrs
示例#4
0
def ApplyMomentum(weight,
                  grad,
                  accum,
                  lr_mat,
                  momt_mat,
                  use_nesterov=False,
                  grad_scale=1.0,
                  target=utils.CCE):
    """
    Apply momentum operator.

    Note:
        apply momentum is an op with inplace computing and binds is used.

    Args:
        weight (tvm.tensor.Tensor): weight tensor to be updated.
        grad (tvm.tensor.Tensor): gradient tensor.
        accum (tvm.tensor.Tensor): accum tensor to be updated.
        lr_mat (tvm.tensor.Tensor): tensor with shape (1,).
        momt_mat (tvm.tensor.Tensor): momt_mat tensor with shape (1,).
        use_nesterov (bool): Default value is False.
        grad_scale (float): Default value is 1.0

    Returns:
        fake_output: Invalid value, just to suit for framework.
        accum_inplace: tvm.tensor.Tensor, updated accum.
        weight_inplace: tvm.tensor.Tensor, updated weight.
        attrs: dict.
    """
    shape = [x.value for x in weight.shape]

    # weight, grad and accum must match elementwise and all be float
    utils.elemwise_shape_check(weight.shape, grad.shape)
    utils.elemwise_shape_check(weight.shape, accum.shape)
    utils.ops_dtype_check([weight.dtype, grad.dtype, accum.dtype],
                          utils.DtypeForDavinci.ALL_FLOAT)

    # scaled gradient: grad * grad_scale
    grad = akg.tvm.compute(
        shape,
        lambda *idx: grad(*idx) * akg.tvm.const(grad_scale, grad.dtype),
        name="grad")
    # accum * momentum
    momt_accum = akg.tvm.compute(shape,
                                 lambda *idx: accum(*idx) * momt_mat[0],
                                 name="momt_accum")
    # accum_new = accum * momentum + grad
    accum_inplace = akg.tvm.compute(
        shape,
        lambda *idx: momt_accum(*idx) + grad(*idx),
        name="accum_inplace")

    if use_nesterov:
        # weight_new = weight - grad * lr - accum_new * momentum * lr
        weight_inplace = akg.tvm.compute(
            shape,
            lambda *idx: weight(*idx) - grad(*idx) * lr_mat[
                0] - accum_inplace(*idx) * momt_mat[0] * lr_mat[0],
            name="weight_inplace")
    else:
        # weight_new = weight - accum_new * lr
        sum_grad = akg.tvm.compute(
            shape,
            lambda *idx: accum_inplace(*idx) * lr_mat[0],
            name="nesterov_lr")
        weight_inplace = akg.tvm.compute(
            shape,
            lambda *idx: weight(*idx) - sum_grad(*idx),
            name="weight_inplace")

    # bind the updated tensors back onto the original buffers (inplace)
    weight_inplace, weight_binds_info = TensorUtils.inplace_set(
        weight, weight_inplace, "data_buf")
    accum_inplace, accum_binds_info = TensorUtils.inplace_set(
        accum, accum_inplace, "accum_buf")
    binds_info_all = dict(weight_binds_info)
    binds_info_all.update(accum_binds_info)
    attrs = {utils.BINDS: binds_info_all}

    # fake_output is never consumed; it only satisfies the ME framework's
    # expectation of an extra output tensor.
    fake_output = akg.tvm.compute(shape,
                                  lambda *idx: momt_accum(*idx),
                                  name="fake_output")
    return fake_output, accum_inplace, weight_inplace, attrs