def test_expand_short():
    """Check `et` with the string shorthand expansion ``'k->k*5'``.

    Builds a ``(B, T, D)`` array of ones, inserts a singleton axis at
    position 1, and asserts that `et` returns ``(-1, 5, -1, -1)`` for the
    source dims ``(B, K, T, D)`` and that array's shape.
    """
    B, T, D, K = get_dim_vars('b t d k')

    # Indexing with ``None`` at position 1 adds the length-1 axis that the
    # annotation labels ``k``.
    x: 'bktd' = np.ones((B, T, D))[:, None]

    source_dims = (B, K, T, D)
    expand_shape = et(src=source_dims, expansions='k->k*5', in_shape=x.shape)

    expected = (-1, 5, -1, -1)
    assert expand_shape == expected
    print('test_expand_short: all assertions hold')
def test_expand_short():
    """Check `et` with the string shorthand expansion ``'k->k*5'``.

    ``B``, ``T``, ``D`` and ``K`` are not bound inside this function; they
    are read from the enclosing scope (presumably module-level dim
    variables -- confirm against the rest of the file).

    NOTE(review): an earlier definition in this file has the same name;
    this later one is the definition that stays bound after import.
    """
    # Indexing with ``None`` at position 1 adds the length-1 axis that the
    # annotation labels ``k``.
    x: 'bktd' = np.ones((B, T, D))[:, None]

    source_dims = (B, K, T, D)
    expand_shape = et(src=source_dims, expansions='k->k*5', in_shape=x.shape)

    expected = (-1, 5, -1, -1)
    assert expand_shape == expected
    print('test_expand_short: all assertions hold')
def test_expand():
    """Check `et` with an explicit ``[(K, K * 5)]`` expansion list.

    Builds a ``(B, T, D)`` array of ones, inserts a singleton axis at
    position 1, and asserts that `et` returns ``(-1, 5, -1, -1)`` for the
    source dims ``(B, K, T, D)`` and that array's shape.
    """
    B, T, D, K = get_dim_vars('b t d k')

    # Indexing with ``None`` at position 1 adds the length-1 axis that the
    # annotation labels ``K``.
    x: (B, K, T, D) = np.ones((B, T, D))[:, None]

    # Intended mapping: (B, K, T, D) -> (B, K*5, T, D)
    source_dims = (B, K, T, D)
    rules = [(K, K * 5)]
    expand_shape = et(src=source_dims, expansions=rules, in_shape=x.shape)

    expected = (-1, 5, -1, -1)
    assert expand_shape == expected
    print('test_expand: all assertions hold')
def test_expand():
    """Check `et` with an explicit ``[(K, K * 5)]`` expansion list.

    ``B``, ``T``, ``D`` and ``K`` are not bound inside this function; they
    are read from the enclosing scope (presumably module-level dim
    variables -- confirm against the rest of the file).

    NOTE(review): an earlier definition in this file has the same name;
    this later one is the definition that stays bound after import.
    """
    # Indexing with ``None`` at position 1 adds the length-1 axis that the
    # annotation labels ``K``.
    x: (B, K, T, D) = np.ones((B, T, D))[:, None]

    source_dims = (B, K, T, D)
    rules = [(K, K * 5)]
    expand_shape = et(src=source_dims, expansions=rules, in_shape=x.shape)

    expected = (-1, 5, -1, -1)
    assert expand_shape == expected
    print('test_expand: all assertions hold')