예제 #1
0
파일: __init__.py 프로젝트: ssokhey/CNTK
def squared_error(output, target, name=''):
    '''
    This operation computes the sum of the squared difference between elements
    in the two input matrices. The result is a scalar (i.e., one by one matrix).
    This is often used as a training criterion.

    Example:
        >>> i1 = C.input_variable((1,2))
        >>> i2 = C.input_variable((1,2))
        >>> C.squared_error(i1,i2).eval({i1:np.asarray([[[2., 1.]]], dtype=np.float32), i2:np.asarray([[[4., 6.]]], dtype=np.float32)})
        array([ 29.], dtype=float32)

        >>> C.squared_error(i1,i2).eval({i1:np.asarray([[[1., 2.]]], dtype=np.float32), i2:np.asarray([[[1., 2.]]], dtype=np.float32)})
        array([ 0.], dtype=float32)

    Args:
        output: the output values from the network
        target: the target values that ``output`` is compared against
         (in the examples above it has the same shape as ``output``)
        name (str, optional): the name of the Function instance in the network
    Returns:
        :class:`~cntk.ops.functions.Function`
    '''
    # Alias the native op so that it does not shadow this wrapper's own name.
    from cntk.cntk_py import squared_error as _native_squared_error
    # Both operands are sanitized with one common dtype (chosen from both
    # inputs), so the native op receives matching input types.
    dtype = get_data_type(output, target)
    output = sanitize_input(output, dtype)
    target = sanitize_input(target, dtype)
    return _native_squared_error(output, target, name)
예제 #2
0
def squared_error(output, target, name=''):
    '''
    This operation computes the sum of the squared difference between elements
    in the two input matrices. The result is a scalar (i.e., one by one matrix).
    This is often used as a training criterion.

    Example:
        >>> i1 = C.input_variable((1,2))
        >>> i2 = C.input_variable((1,2))
        >>> C.squared_error(i1,i2).eval({i1:np.asarray([[[[2., 1.]]]], dtype=np.float32), i2:np.asarray([[[[4., 6.]]]], dtype=np.float32)})
        array([[ 29.]], dtype=float32)

        >>> C.squared_error(i1,i2).eval({i1:np.asarray([[[[1., 2.]]]], dtype=np.float32), i2:np.asarray([[[[1., 2.]]]], dtype=np.float32)})
        array([[ 0.]], dtype=float32)

    Args:
        output: the output values from the network
        target: the target values that ``output`` is compared against
         (in the examples above it has the same shape as ``output``)
        name (str, optional): the name of the Function instance in the network
    Returns:
        :class:`~cntk.ops.functions.Function`
    '''
    # Alias the native op so that it does not shadow this wrapper's own name.
    from cntk.cntk_py import squared_error as _native_squared_error
    # Both operands are sanitized with one common dtype (chosen from both
    # inputs), so the native op receives matching input types.
    dtype = get_data_type(output, target)
    output = sanitize_input(output, dtype)
    target = sanitize_input(target, dtype)
    return _native_squared_error(output, target, name)
예제 #3
0
def dice_coefficient(x, y):
    '''
    Placeholder for a Dice coefficient between ``x`` and ``y``.

    NOTE(review): despite its name, this currently returns the squared error
    between ``x`` and ``y`` (the Function produced by ``squared_error``, named
    ``"se"``), not the Dice coefficient.

    Args:
        x: first operand, passed to ``squared_error`` as ``output``
        y: second operand, passed to ``squared_error`` as ``target``
    Returns:
        whatever ``squared_error(x, y, "se")`` returns
    '''
    # TODO: implement the (soft) Dice coefficient,
    #   2 * reduce_sum(x * y) / (reduce_sum(x) + reduce_sum(y)),
    # see https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
    # An earlier draft used reduce_sum(x - y) as the "intersection". That is not
    # an intersection; the element-wise product is.
    return squared_error(x, y, "se")