def test_join_no_momentum(self):
    """Join two triplets of arrays when no momentum state is supplied.

    Each component of the joined result must equal the first triplet's
    component followed by the second triplet's component. Passing
    ``None`` as the state must produce an empty dict.
    """
    from ipu_sparse_ops import sparse_training

    lhs = (
        np.array([1, 2, 3]),
        np.array([4, 5, 6]),
        np.array([0.1, 0.2, 0.3]),
    )
    rhs = (
        np.array([1, 2, 3]),
        np.array([4, 5, 6]),
        np.array([0.4, 0.5, 0.6]),
    )
    # NOTE(review): the trailing 3 is passed straight through; its meaning
    # is defined by join_triplets -- confirm against that signature.
    joined, state = sparse_training.join_triplets(lhs, rhs, None, 3)

    expected = (
        [1, 2, 3, 1, 2, 3],
        [4, 5, 6, 4, 5, 6],
        [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
    )
    # Index explicitly so a too-short result still fails with IndexError.
    for position, want in enumerate(expected):
        assert_equal(joined[position], want)
    assert state == {}
def test_join(self):
    """Join two triplets of arrays while carrying a momentum state dict.

    The joined components must be the concatenation of the two inputs.
    The returned ``'momentum'`` array must be the original three values
    followed by three zeros.
    """
    from ipu_sparse_ops import sparse_training

    lhs = (
        np.array([1, 2, 3]),
        np.array([2, 4, 6]),
        np.array([0.1, 0.2, 0.3]),
    )
    rhs = (
        np.array([4, 5, 6]),
        np.array([8, 10, 12]),
        np.array([0.4, 0.5, 0.6]),
    )
    state_in = {'momentum': np.array([1.4, 1.5, 1.6])}
    # NOTE(review): the trailing 3 is passed straight through; its meaning
    # is defined by join_triplets -- confirm against that signature.
    joined, state_out = sparse_training.join_triplets(lhs, rhs, state_in, 3)

    expected = (
        [1, 2, 3, 4, 5, 6],
        [2, 4, 6, 8, 10, 12],
        [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
    )
    # Index explicitly so a too-short result still fails with IndexError.
    for position, want in enumerate(expected):
        assert_equal(joined[position], want)
    assert_equal(state_out['momentum'], [1.4, 1.5, 1.6, 0, 0, 0])