def test_intersect_indices_sparse(self):
  """Checks `input_fn._intersect_indices` when the first input is sparse.

  Builds a sparse `obs_code` tensor and a sparse `obs_harm_code` tensor,
  runs `_intersect_indices` on them, and compares the resulting
  SparseTensor's values, indices and dense shape against fixed expectations.
  """
  code_indices = [
      [0, 0, 0], [0, 1, 0], [0, 2, 0], [1, 1, 0],
      [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 2, 0],
  ]
  code_values = [
      'loinc:1', 'loinc:1', 'loinc:2', 'loinc:4',
      'loinc:4', 'loinc:2', 'loinc:1', 'loinc:4',
  ]
  obs_code = tf.SparseTensor(
      indices=code_indices, values=code_values, dense_shape=[3, 4, 1])

  # NOTE(review): the last index [2, 2, 0] lies outside dense_shape[1] == 2;
  # confirm this out-of-range entry is intentional fixture data.
  harm_indices = [[0, 0, 0], [0, 1, 0], [1, 1, 0], [2, 1, 0], [2, 2, 0]]
  harm_values = [
      'pulse', 'pulse', 'blood_pressure', 'temperature', 'temperature',
  ]
  obs_harm_code = tf.SparseTensor(
      indices=harm_indices, values=harm_values, dense_shape=[3, 2, 3])

  # Session results come back as bytes, hence the b'' literals.
  expected_indices = [[0, 0, 0], [0, 1, 0], [1, 0, 0], [2, 1, 0]]
  expected_values = [b'loinc:1', b'loinc:1', b'loinc:4', b'loinc:4']
  expected_dense_shape = [3, 2, 1]

  intersected = input_fn._intersect_indices(obs_code, obs_harm_code)
  with self.test_session() as sess:
    result = sess.run(intersected)
    self.assertAllEqual(expected_values, result.values)
    self.assertAllEqual(expected_indices, result.indices)
    self.assertAllEqual(expected_dense_shape, result.dense_shape)
def test_intersect_indices_dense(self):
  """Checks `input_fn._intersect_indices` when the first input is dense.

  Runs `_intersect_indices` on a dense `delta_time` tensor and a sparse
  `obs_harm_code` tensor, then compares the evaluated result against a
  fixed nested-list expectation.
  """
  delta_time = tf.constant([
      [1, 2, 3, 4, 5],
      [10, 20, 30, 40, 0],
      [100, 200, 300, 0, 0],
  ])

  harm_indices = [[0, 0, 0], [0, 1, 0], [1, 1, 2], [2, 1, 0]]
  harm_values = ['pulse', 'pulse', 'blood_pressure', 'temperature']
  obs_harm_code = tf.SparseTensor(
      indices=harm_indices, values=harm_values, dense_shape=[3, 2, 3])

  expected_delta_time = [
      [[1], [2]],
      [[20], [0]],
      [[200], [0]],
  ]

  intersected = input_fn._intersect_indices(delta_time, obs_harm_code)
  with self.test_session() as sess:
    actual_delta_time = sess.run(intersected)
    self.assertAllClose(expected_delta_time, actual_delta_time)