示例#1
0
def test_expand_short():
    """Check the shorthand expansion spec 'k->k*5' on a (B, K, T, D) source.

    A length-1 axis is inserted at index 1 (the K position of ``src``), and
    the expected expand shape is (-1, 5, -1, -1): only that axis changes.
    """
    B, T, D, K = get_dim_vars('b t d k')

    base: 'btd' = np.ones((B, T, D))
    widened: 'bktd' = base[:, np.newaxis]  # np.newaxis is None: same indexing

    source_dims = (B, K, T, D)
    result = et(src=source_dims, expansions='k->k*5', in_shape=widened.shape)
    assert result == (-1, 5, -1, -1)
    print('test_expand_short: all assertions hold')
示例#2
0
def test_expand_short():
    """Check the shorthand expansion spec 'k->k*5' on a (B, K, T, D) source.

    This function does not define B, T, D, K or ``et`` itself; they are
    expected at module scope (not visible in this snippet).
    """
    x: 'btd' = np.ones((B, T, D))
    x: 'bktd' = x[:, None]  # add a length-1 axis at index 1 (the K slot of src)
    expand_shape = et(src=(B, K, T, D), expansions='k->k*5', in_shape=x.shape)
    # Report the inputs and actual result on failure instead of keeping
    # commented-out debug prints around.
    assert expand_shape == (-1, 5, -1, -1), (
        f'Expanding {(B, K, T, D)} by "k->k*5" gave {expand_shape}'
    )
    print('test_expand_short: all assertions hold')
示例#3
0
def test_expand():
    """Check expansion given as an explicit list of (dim, new_dim) pairs.

    The K axis of a (B, K, T, D) source is mapped to K * 5. The expected
    expand shape is (-1, 5, -1, -1).
    """
    B, T, D, K = get_dim_vars('b t d k')

    arr: (B, T, D) = np.ones((B, T, D))
    arr: (B, K, T, D) = arr[:, np.newaxis]  # np.newaxis is None

    # (B, K, T, D) -> (B, K*5, T, D)
    rules = [(K, K * 5)]
    got = et(src=(B, K, T, D), expansions=rules, in_shape=arr.shape)
    assert got == (-1, 5, -1, -1)
    print('test_expand: all assertions hold')
示例#4
0
def test_expand():
    """Check expansion given as an explicit list of (dim, new_dim) pairs.

    This function does not define B, T, D, K or ``et`` itself; they are
    expected at module scope (not visible in this snippet).
    """
    x: (B, T, D) = np.ones((B, T, D))
    x: (B, K, T, D) = x[:, None]  # add a length-1 axis at index 1 (the K slot)

    expansions = [(K, K * 5)]  # (B, K, T, D) -> (B, K*5, T, D)
    expand_shape = et(src=(B, K, T, D), expansions=expansions, in_shape=x.shape)
    # Put the diagnostic into the assertion message rather than keeping
    # commented-out debug prints around.
    assert expand_shape == (-1, 5, -1, -1), (
        f'Expanding {(B, K, T, D)} by {expansions} gave {expand_shape}'
    )

    print('test_expand: all assertions hold')