# Example #1
# 0
def test_softmax_expr():
    """SoftmaxExpr over two constants becomes a call to an emitted softmax() helper.

    The expected output also pins the helper's body: it finds the maximum
    element, exponentiates each element minus that maximum, and divides every
    entry by the accumulated sum.
    """
    operands = [ast.NumVal(2.0), ast.NumVal(3.0)]
    softmax_ast = ast.SoftmaxExpr(operands)

    js = JavascriptInterpreter()

    expected_code = """
function score(input) {
    return softmax([2.0, 3.0]);
}
function softmax(x) {
    let size = x.length;
    let result = new Array(size);
    let max = x[0];
    for (let i = 1; i < size; ++i) {
        if (x[i] > max)
            max = x[i];
    }
    let sum = 0.0;
    for (let i = 0; i < size; ++i) {
        result[i] = Math.exp(x[i] - max);
        sum += result[i];
    }
    for (let i = 0; i < size; ++i)
        result[i] /= sum;
    return result;
}
"""

    generated = js.interpret(softmax_ast)
    utils.assert_code_equal(generated, expected_code)
# Example #2
# 0
def test_dependable_condition():
    """An IfExpr used inside a comparison operand is hoisted into its own variable.

    The inner conditional is emitted first into ``var1`` so that the outer
    condition can reference it. The outer conditional's result is stored in
    ``var0`` and returned.
    """
    inner_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    lhs = ast.BinNumExpr(inner_if, ast.NumVal(2), ast.BinNumOpType.ADD)
    rhs = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV)
    condition = ast.CompExpr(lhs, rhs, ast.CompOpType.GTE)

    outer_if = ast.IfExpr(condition, ast.NumVal(1), ast.FeatureRef(0))

    # NOTE(review): CompOpType.EQ is expected as "==" here, but test_if_expr
    # in this listing expects "===". Both cannot hold for the same
    # JavascriptInterpreter version — confirm which operator it emits.
    expected_code = """
function score(input) {
    var var0;
    var var1;
    if ((1.0) == (1.0)) {
        var1 = 1.0;
    } else {
        var1 = 2.0;
    }
    if (((var1) + (2.0)) >= ((1.0) / (2.0))) {
        var0 = 1.0;
    } else {
        var0 = input[0];
    }
    return var0;
}
"""

    js = JavascriptInterpreter()
    generated = js.interpret(outer_if)
    utils.assert_code_equal(generated, expected_code)
# Example #3
# 0
def test_bin_vector_num_expr():
    """Vector-times-number renders as a call to the mulVectorNumber() helper.

    Per the expected output, both vector helpers (addVectors and
    mulVectorNumber) are appended after score(), even though only the
    multiplication helper is called.
    """
    vector = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    product = ast.BinVectorNumExpr(vector, ast.NumVal(1), ast.BinNumOpType.MUL)

    js = JavascriptInterpreter()

    expected_code = """
function score(input) {
    return mulVectorNumber([1.0, 2.0], 1.0);
}
function addVectors(v1, v2) {
    let result = new Array(v1.length);
    for (let i = 0; i < v1.length; i++) {
        result[i] = v1[i] + v2[i];
    }
    return result;
}
function mulVectorNumber(v1, num) {
    let result = new Array(v1.length);
    for (let i = 0; i < v1.length; i++) {
        result[i] = v1[i] * num;
    }
    return result;
}
"""
    generated = js.interpret(product)
    utils.assert_code_equal(generated, expected_code)
# Example #4
# 0
def test_atan_expr():
    """AtanExpr maps directly onto JavaScript's built-in Math.atan()."""
    atan_ast = ast.AtanExpr(ast.NumVal(2.0))

    expected_code = """
function score(input) {
    return Math.atan(2.0);
}
"""

    generated = JavascriptInterpreter().interpret(atan_ast)
    assert_code_equal(generated, expected_code)
# Example #5
# 0
def test_raw_array():
    """A bare VectorVal is returned as a JavaScript array literal."""
    vector_ast = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])

    expected_code = """
function score(input) {
    return [3.0, 4.0];
}
"""

    generated = JavascriptInterpreter().interpret(vector_ast)
    utils.assert_code_equal(generated, expected_code)
# Example #6
# 0
def test_pow_expr():
    """PowExpr is rendered as Math.pow(base, exponent)."""
    base, exponent = ast.NumVal(2.0), ast.NumVal(3.0)
    pow_ast = ast.PowExpr(base, exponent)

    expected_code = """
function score(input) {
    return Math.pow(2.0, 3.0);
}
"""

    generated = JavascriptInterpreter().interpret(pow_ast)
    assert_code_equal(generated, expected_code)
# Example #7
# 0
def test_log1p_expr():
    """Log1pExpr maps directly onto JavaScript's built-in Math.log1p()."""
    log1p_ast = ast.Log1pExpr(ast.NumVal(2.0))

    js = JavascriptInterpreter()

    expected_code = """
function score(input) {
    return Math.log1p(2.0);
}
"""

    generated = js.interpret(log1p_ast)
    utils.assert_code_equal(generated, expected_code)
# Example #8
# 0
def test_abs_expr():
    """AbsExpr maps onto Math.abs(), keeping the sign of the literal argument."""
    abs_ast = ast.AbsExpr(ast.NumVal(-1.0))

    js = JavascriptInterpreter()

    expected_code = """
function score(input) {
    return Math.abs(-1.0);
}
"""

    generated = js.interpret(abs_ast)
    utils.assert_code_equal(generated, expected_code)
# Example #9
# 0
def test_reused_expr():
    """A subexpression marked ``to_reuse=True`` is computed once and referenced twice.

    The same ExpExpr instance is both operands of a division. The expected
    output stores it in ``var0`` and divides that variable by itself.
    """
    shared = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True)
    quotient = ast.BinNumExpr(shared, shared, ast.BinNumOpType.DIV)

    expected_code = """
function score(input) {
    var var0;
    var0 = Math.exp(1.0);
    return (var0) / (var0);
}
"""

    generated = JavascriptInterpreter().interpret(quotient)
    assert_code_equal(generated, expected_code)
# Example #10
# 0
def test_bin_num_expr():
    """Nested binary arithmetic is fully parenthesised, including every operand."""
    quotient = ast.BinNumExpr(
        ast.FeatureRef(0), ast.NumVal(-2), ast.BinNumOpType.DIV)
    product = ast.BinNumExpr(quotient, ast.NumVal(2), ast.BinNumOpType.MUL)

    expected_code = """
function score(input) {
    return ((input[0]) / (-2.0)) * (2.0);
}
"""

    generated = JavascriptInterpreter().interpret(product)
    assert_code_equal(generated, expected_code)
# Example #11
# 0
def test_sigmoid_expr():
    """SigmoidExpr becomes a call to an emitted sigmoid() helper.

    The expected helper branches on the sign of ``x``. For negative inputs it
    uses exp(x) / (1 + exp(x)); otherwise it uses 1 / (1 + exp(-x)).
    """
    sigmoid_ast = ast.SigmoidExpr(ast.NumVal(2.0))

    expected_code = """
function score(input) {
    return sigmoid(2.0);
}
function sigmoid(x) {
    if (x < 0.0) {
        let z = Math.exp(x);
        return z / (1.0 + z);
    }
    return 1.0 / (1.0 + Math.exp(-x));
}
"""

    generated = JavascriptInterpreter().interpret(sigmoid_ast)
    assert_code_equal(generated, expected_code)
# Example #12
# 0
def test_if_expr():
    """An IfExpr becomes an if/else that assigns ``var0``, which is then returned."""
    equality = ast.CompExpr(
        ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ)
    conditional = ast.IfExpr(equality, ast.NumVal(2), ast.NumVal(3))

    # NOTE(review): EQ is expected as "===" here, whereas the other
    # conditional tests in this listing expect "==". Confirm which operator
    # the target JavascriptInterpreter emits.
    expected_code = """
function score(input) {
    var var0;
    if ((1.0) === (input[0])) {
        var0 = 2.0;
    } else {
        var0 = 3.0;
    }
    return var0;
}
"""

    generated = JavascriptInterpreter().interpret(conditional)
    assert_code_equal(generated, expected_code)
# Example #13
# 0
def test_nested_condition():
    """A condition shared by an outer and a nested IfExpr is re-emitted in the inner scope.

    The same ``condition`` object guards both the outer IfExpr and the IfExpr
    in its true branch. In the expected output, the inner conditional helper
    is computed again into ``var2`` inside the outer branch rather than
    reusing ``var1``.
    """
    inner_if = ast.IfExpr(
        ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ),
        ast.NumVal(1),
        ast.NumVal(2))
    sum_expr = ast.BinNumExpr(inner_if, ast.NumVal(2), ast.BinNumOpType.ADD)

    condition = ast.CompExpr(ast.NumVal(1), sum_expr, ast.CompOpType.EQ)

    nested = ast.IfExpr(condition, ast.FeatureRef(2), ast.NumVal(2))
    outer = ast.IfExpr(condition, nested, ast.NumVal(2))

    # NOTE(review): EQ is expected as "==" here, but test_if_expr in this
    # listing expects "===". Confirm which operator the target
    # JavascriptInterpreter emits.
    expected_code = """
function score(input) {
    var var0;
    var var1;
    if ((1.0) == (1.0)) {
        var1 = 1.0;
    } else {
        var1 = 2.0;
    }
    if ((1.0) == ((var1) + (2.0))) {
        var var2;
        if ((1.0) == (1.0)) {
            var2 = 1.0;
        } else {
            var2 = 2.0;
        }
        if ((1.0) == ((var2) + (2.0))) {
            var0 = input[2];
        } else {
            var0 = 2.0;
        }
    } else {
        var0 = 2.0;
    }
    return var0;
}
"""

    js = JavascriptInterpreter()
    generated = js.interpret(outer)
    utils.assert_code_equal(generated, expected_code)
# Example #14
# 0
def test_multi_output():
    """An IfExpr whose branches are vectors assigns array literals to ``var0``."""
    condition = ast.CompExpr(ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ)
    when_true = ast.VectorVal([ast.NumVal(1), ast.NumVal(2)])
    when_false = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])
    conditional = ast.IfExpr(condition, when_true, when_false)

    # NOTE(review): EQ is expected as "==" here, but test_if_expr in this
    # listing expects "===". Confirm which operator the target
    # JavascriptInterpreter emits.
    expected_code = """
function score(input) {
    var var0;
    if ((1.0) == (1.0)) {
        var0 = [1.0, 2.0];
    } else {
        var0 = [3.0, 4.0];
    }
    return var0;
}
"""

    js = JavascriptInterpreter()
    generated = js.interpret(conditional)
    utils.assert_code_equal(generated, expected_code)