def apply_ftrl(var, accum, linear, grad, lr, l1, l2, lr_power, target=utils.CCE):
    """
    Ftrl-proximal optimization algorithm.

    Note:
        accum_new = accum + grad * grad
        linear_new = linear + grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
        x = clip(linear_new, -l1, l1) - linear_new
        y = accum_new^(-lr_power) / lr + 2 * l2
        var_new = x / y

    Args:
        var (tvm.tensor.Tensor): The tensor to be updated. Should be float16 or float32.
        accum (tvm.tensor.Tensor): A tensor of same shape and type as var. Each entry in it must be
            greater than or equal to zero.
        linear (tvm.tensor.Tensor): A tensor of same shape and type as var.
        grad (tvm.tensor.Tensor): A tensor of same shape and type as var.
        lr (tvm.tensor.Tensor): A scalar tensor of the same type as `var`.
        l1 (tvm.tensor.Tensor): A scalar tensor of the same type as `var`.
        l2 (tvm.tensor.Tensor): A scalar tensor of the same type as `var`.
        lr_power (tvm.tensor.Tensor): A scalar tensor of the same type as `var`. Its value must be
            less than or equal to zero.

    Returns:
        tvm.tensor.Tensor, updated var.
        tvm.tensor.Tensor, updated accum.
        tvm.tensor.Tensor, updated linear.
    """
    # The vlog instruction on the mini product has a precision problem, and the mini product
    # is used for inference rather than training.
    if product_is_mini():
        raise RuntimeError("The apply_ftrl operator does not support the mini product")

    # check shape
    utils.check_shape(var)
    shape = get_shape(var)
    for tensor in (accum, linear, grad):
        utils.elemwise_shape_check(shape, tensor.shape)
    scalar_shape = (1,)
    for scalar in (lr, l1, l2, lr_power):
        utils.elemwise_shape_check(scalar.shape, scalar_shape)

    # check dtype
    dtype = var.dtype
    utils.ops_dtype_check(dtype, [utils.DtypeForDavinci.FLOAT16, utils.DtypeForDavinci.FLOAT32])
    for tensor in (var, accum, linear, grad, lr, l1, l2, lr_power):
        utils.elemwise_dtype_check(tensor.dtype, dtype)

    var_new, accum_new, linear_new = apply_ftrl_impl(var, accum, linear, grad, lr, l1, l2, None,
                                                     lr_power, with_l2_shrinkage=False)

    # update by inplace
    (var_new, accum_new, linear_new), binds_info = \
        TensorUtils.inplace_set_tensors((var, accum, linear), (var_new, accum_new, linear_new))
    attrs = {utils.BINDS: binds_info}
    return var_new, accum_new, linear_new, attrs
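

# Illustrative only: a minimal NumPy sketch of the FTRL-proximal update documented above,
# useful as a reference when checking the operator's numerics. The function name and the
# use of NumPy are assumptions for demonstration; this sketch is not part of the operator's API.
def _ftrl_numpy_reference(var, accum, linear, grad, lr, l1, l2, lr_power):
    """Reference FTRL-proximal step; var, accum, linear, grad are NumPy arrays, the rest floats."""
    import numpy as np
    accum_new = accum + grad * grad
    # sigma generalizes (sqrt(accum_new) - sqrt(accum)) / lr via the lr_power exponent
    sigma = (accum_new ** (-lr_power) - accum ** (-lr_power)) / lr
    linear_new = linear + grad - sigma * var
    x = np.clip(linear_new, -l1, l1) - linear_new
    y = accum_new ** (-lr_power) / lr + 2.0 * l2
    var_new = x / y
    return var_new, accum_new, linear_new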
def apply_proximal_adagrad(var, accum, lr, l1, l2, grad, target=utils.CCE):
    """
    The FOBOS optimization algorithm with Adagrad learning rate.

    Note:
        accum_new = accum + grad * grad
        ada_lr = lr * rsqrt(accum_new)
        prox_var = var - ada_lr * grad
        if l1 > 0:
            var_new = sign(prox_var) / (1 + ada_lr * l2) * max(|prox_var| - ada_lr * l1, 0)
        else:
            var_new = prox_var / (1 + ada_lr * l2)

    Args:
        var (tvm.tensor.Tensor): The tensor to be updated. Should be float16 or float32.
        accum (tvm.tensor.Tensor): A tensor of same shape and type as var. Each entry in it must be
            greater than or equal to zero.
        lr (tvm.tensor.Tensor): A scalar tensor of the same type as `var`.
        l1 (tvm.tensor.Tensor): A scalar tensor of the same type as `var`.
        l2 (tvm.tensor.Tensor): A scalar tensor of the same type as `var`.
        grad (tvm.tensor.Tensor): A tensor of same shape and type as var.

    Returns:
        tvm.tensor.Tensor, updated var.
        tvm.tensor.Tensor, updated accum.
    """
    # check shape
    utils.check_shape(var)
    shape = get_shape(var)
    for tensor in (accum, grad):
        utils.elemwise_shape_check(shape, tensor.shape)
    scalar_shape = (1,)
    for scalar in (lr, l1, l2):
        utils.elemwise_shape_check(scalar.shape, scalar_shape)

    # check dtype
    dtype = var.dtype
    utils.ops_dtype_check(dtype, [utils.DtypeForDavinci.FLOAT16, utils.DtypeForDavinci.FLOAT32])
    for tensor in (var, accum, lr, l1, l2, grad):
        utils.elemwise_dtype_check(tensor.dtype, dtype)

    var_new, accum_new = _apply_proximal_adagrad_compute(var, accum, lr, l1, l2, grad)

    # update by inplace
    (var_new, accum_new), binds_info = TensorUtils.inplace_set_tensors([var, accum],
                                                                       [var_new, accum_new])
    attrs = {utils.BINDS: binds_info}
    return var_new, accum_new, attrs
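

# Illustrative only: a NumPy sketch of the proximal-Adagrad (FOBOS) update documented above.
# The function name and NumPy usage are assumptions for demonstration, not part of this
# module's API.
def _proximal_adagrad_numpy_reference(var, accum, lr, l1, l2, grad):
    """Reference proximal-Adagrad step; lr, l1, l2 are Python floats, the rest NumPy arrays."""
    import numpy as np
    accum_new = accum + grad * grad
    ada_lr = lr / np.sqrt(accum_new)  # lr * rsqrt(accum_new), element-wise
    prox_var = var - ada_lr * grad
    if l1 > 0:
        # soft-thresholding: shrink |prox_var| by ada_lr * l1 and keep the sign
        var_new = np.sign(prox_var) / (1.0 + ada_lr * l2) * \
            np.maximum(np.abs(prox_var) - ada_lr * l1, 0.0)
    else:
        var_new = prox_var / (1.0 + ada_lr * l2)
    return var_new, accum_new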
def apply_adam(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad,
               use_nesterov=False, target=utils.CCE):
    """
    Adam and Nadam optimization algorithm.

    Note:
        lr_t = lr * sqrt(1 - beta2_power) / (1 - beta1_power)
        m_new = m + (1 - beta1) * (grad - m)
        v_new = v + (1 - beta2) * (grad * grad - v)
        if use_nesterov == True:
            var_new = var - lr_t * (m_new * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_new))
        else:
            var_new = var - lr_t * m_new / (epsilon + sqrt(v_new))

    Args:
        var (tvm.tensor.Tensor): The tensor to be updated. Should be float16 or float32.
        m (tvm.tensor.Tensor): The first moment estimate. A tensor of same shape and type as var.
        v (tvm.tensor.Tensor): The second moment estimate. A tensor of same shape and type as var.
        beta1_power (tvm.tensor.Tensor): A scalar tensor of the same type as `var`.
        beta2_power (tvm.tensor.Tensor): A scalar tensor of the same type as `var`.
        lr (tvm.tensor.Tensor): The learning rate. A scalar tensor of the same type as `var`.
        beta1 (tvm.tensor.Tensor): A scalar tensor of the same type as `var`.
        beta2 (tvm.tensor.Tensor): A scalar tensor of the same type as `var`.
        epsilon (tvm.tensor.Tensor): A scalar tensor of the same type as `var`.
        grad (tvm.tensor.Tensor): A tensor of same shape and type as var.
        use_nesterov (bool): Defaults to False. If True, the Nadam algorithm is implemented;
            otherwise the Adam algorithm is implemented.

    Returns:
        tvm.tensor.Tensor, updated var.
        tvm.tensor.Tensor, updated m.
        tvm.tensor.Tensor, updated v.
    """
    # check shape
    utils.check_shape(var)
    shape = get_shape(var)
    for tensor in (m, v, grad):
        utils.elemwise_shape_check(shape, tensor.shape)
    scalar_shape = (1,)
    for scalar in (beta1_power, beta2_power, lr, beta1, beta2, epsilon):
        utils.elemwise_shape_check(scalar.shape, scalar_shape)

    # check dtype
    dtype = var.dtype
    utils.ops_dtype_check(dtype, [utils.DtypeForDavinci.FLOAT16, utils.DtypeForDavinci.FLOAT32])
    for tensor in (var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
        utils.elemwise_dtype_check(tensor.dtype, dtype)

    var_new, m_new, v_new = _apply_adam_compute(var, m, v, beta1_power, beta2_power, lr, beta1,
                                                beta2, epsilon, grad, use_nesterov)

    # update by inplace
    (var_new, m_new, v_new), binds_info = TensorUtils.inplace_set_tensors(
        [var, m, v], [var_new, m_new, v_new])
    attrs = {utils.BINDS: binds_info}
    return var_new, m_new, v_new, attrs
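

# Illustrative only: a NumPy sketch of the Adam/Nadam update documented above, handy for
# validating the operator against a plain implementation. The function name and NumPy usage
# are assumptions for demonstration; this sketch is not part of this module's API.
def _adam_numpy_reference(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon,
                          grad, use_nesterov=False):
    """Reference Adam/Nadam step; scalars are Python floats, the rest NumPy arrays."""
    import numpy as np
    lr_t = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    m_new = m + (1.0 - beta1) * (grad - m)         # running mean of the gradient
    v_new = v + (1.0 - beta2) * (grad * grad - v)  # running mean of the squared gradient
    # Nadam applies an extra Nesterov-style look-ahead to the first moment
    update = m_new * beta1 + (1.0 - beta1) * grad if use_nesterov else m_new
    var_new = var - lr_t * update / (epsilon + np.sqrt(v_new))
    return var_new, m_new, v_new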