예제 #1
0
def test_convert_variable():
    """Exercise ``TensorType.convert_variable`` on 2-d ``floatX`` types.

    Covers: converting a variable to its own type, converting a variable
    whose first dimension is broadcastable, converting in the opposite
    direction, a rank mismatch (3-d into 2-d), and a constant tensor.
    """
    # 2-d type with no broadcastable dimensions.
    general_type = TensorType(config.floatX, [False, False])
    general_var = general_type()

    # Same dtype and rank, but the first dimension is broadcastable.
    row_type = TensorType(config.floatX, [True, False])
    row_var = row_type()

    # A variable that already has the target type comes back unchanged.
    assert general_type.convert_variable(general_var) is general_var

    # A variable with a broadcastable first dimension is accepted as-is by
    # the non-broadcastable target type.
    assert general_type.convert_variable(row_var) is row_var

    # The reverse direction yields a variable whose type equals the target.
    converted = row_type.convert_variable(general_var)
    assert converted.type == row_type

    # A 3-d variable cannot be converted to a 2-d type: ``None`` is returned.
    rank3_type = TensorType(config.floatX, [True, False, True])
    rank3_var = rank3_type()
    assert row_type.convert_variable(rank3_var) is None

    # A compatible constant tensor is returned unchanged.
    constant = at.as_tensor([[1, 2], [3, 4]], dtype=config.floatX)
    assert general_type.convert_variable(constant) is constant
예제 #2
0
def test_fixed_shape_convert_variable():
    """Exercise ``convert_variable`` across broadcastable-flag and
    fixed-shape ``TensorType`` specifications of ``float64`` matrices."""
    # An all-``True`` broadcastable pattern and a static shape of all 1s
    # describe the same type.
    bcast_type = TensorType("float64", (True, True))
    ones_type = TensorType("float64", (1, 1))

    assert bcast_type == ones_type
    assert bcast_type.shape == ones_type.shape

    # Converting between the two equivalent types is the identity, in
    # either direction.
    ones_var = ones_type()
    assert ones_type.convert_variable(ones_var) is ones_var
    assert bcast_type.convert_variable(ones_var) is ones_var

    bcast_var = bcast_type()
    assert ones_type.convert_variable(bcast_var) is bcast_var

    # An input whose first dimension is not broadcastable is converted
    # through a ``Rebroadcast`` op.
    col_type = TensorType("float64", (False, True))
    col_var = col_type()
    rebroadcast = ones_type.convert_variable(col_var)
    assert isinstance(rebroadcast.owner.op, Rebroadcast)

    # A target with no static shape accepts a fixed-shape (3, 2) input, and
    # the result keeps the input's type and shape.
    free_type = TensorType("float64", (False, False))
    fixed_type = TensorType("float64", (3, 2))
    fixed_var = fixed_type()
    assert free_type.shape == (None, None)
    converted = free_type.convert_variable(fixed_var)
    assert converted.type == fixed_type
    assert converted.type.shape == (3, 2)