def test_large_call():
    """Check a single TWED evaluation against the stored reference.

    Uses the module-level inputs ``A``/``TA`` and ``B``/``TB`` together with
    the shared parameters ``nu``, ``lamb`` and ``degree``. The distance is
    printed and then compared to ``reference_result`` with ``np.allclose``.
    """
    distance = twed(A, TA, B, TB, nu, lamb, degree)
    print(f'Python cuTWED distance: {distance:f}')
    assert np.allclose(distance, reference_result)
def test_basic_call_float():
    """Repeat the reference TWED call with single-precision inputs.

    The four array arguments (``A``, ``TA``, ``B``, ``TB``) are each cast to
    ``np.float32`` before the call. The scalar parameters ``nu``, ``lamb`` and
    ``degree`` are passed unchanged. The result must still match
    ``reference_result`` under ``np.allclose``'s default tolerances.
    """
    # Cast in the same order as the positional arguments to twed.
    as_float32 = [arr.astype(np.float32) for arr in (A, TA, B, TB)]
    dist_single = twed(*as_float32, nu, lamb, degree)
    print(f'Python cuTWED distance (single precision): {dist_single:f}')
    assert np.allclose(dist_single, reference_result)
def test_multi_twed():
    """Build a pairwise TWED matrix and compare it with ``DIST``.

    Series ``AA[i]`` is compared with ``BB[j]`` only when ``j >= i``, so the
    upper triangle (diagonal included) is filled. Entries below the diagonal
    keep the zero they were initialised with. All series share the time
    stamps ``T``.
    """
    pairwise = np.zeros(sz)
    print("Computing ctwed batch")
    for i, series_a in enumerate(AA):
        for j, series_b in enumerate(BB):
            if j >= i:
                pairwise[i][j] = twed(series_a, T, series_b, T, nu, lamb, degree)
    assert np.allclose(pairwise, DIST)
def test_cutwed():
    """Compute cuTWED distances over the synthetic control dataset.

    Every pair of series in ``TS`` with column index >= row index is
    evaluated (progress shown via ``tqdm`` over the rows). All other entries
    of the ``nseries x nseries`` matrix stay zero. The matrix is saved with
    ``np.save`` to ``synthetic_distance_matrix_cutwed.npy``, a path relative
    to the current working directory. The matrix is also returned.
    """
    # NOTE(review): this test asserts nothing, and recent pytest releases emit
    # PytestReturnNotNoneWarning when a test function returns a value. The
    # return is kept in case other code (e.g. a notebook) consumes it.
    matrix = np.zeros((nseries, nseries))
    for r, series_a in enumerate(tqdm(TS)):
        for c, series_b in enumerate(TS):
            if c >= r:
                matrix[r][c] = twed(series_a, T, series_b, T, nu, lamb, degree)

    name = 'synthetic_distance_matrix_cutwed'
    with open(f'{name}.npy', 'wb') as out_file:
        np.save(out_file, matrix)
    return matrix
def test_single_cutwed():
    """Compare one cuTWED distance with ``single_ref``.

    The module-level series ``A`` and ``B`` are both paired with the time
    stamps ``T``. The computed distance is printed before the comparison.
    """
    print("Computing cutwed")
    result = twed(A, T, B, T, nu, lamb, degree)
    print(f"Distance: {result}")
    # Argument order matters: np.allclose scales rtol by its second argument.
    assert np.allclose(single_ref, result)
plt.plot()

# In[9]:

# Notebook display cell: top-right entry of the reference matrix.
DistanceMatrix_ref[0, -1]

# In[10]:

# Compute cuTWED distances for every ordered pair of series in TS, filling
# only entries with col >= row (the rest stay zero).
# This runs at module level, so the loop rebinds the global names
# row, col, A, B and dist.
DistanceMatrix_cu = np.zeros((nseries, nseries))
for (row, A), (col, B) in tqdm(product(enumerate(TS), enumerate(TS)),
                               total=len(TS)**2):
    if col >= row:
        dist = twed(A, T, B, T, nu, lamb, degree)
        DistanceMatrix_cu[row][col] = dist

with sns.axes_style("white"):
    sns.heatmap(DistanceMatrix_cu, square=True, cmap="YlGnBu")
plt.plot()

# In[11]:

# Notebook display cell: top-right entry of the cuTWED matrix.
DistanceMatrix_cu[0, -1]

# In[12]:

# Largest absolute element-wise difference between the cuTWED and reference matrices.
np.max(np.abs(DistanceMatrix_cu - DistanceMatrix_ref))

# In[13]: