def test_intersect_indices_sparse(self):
     """Checks `_intersect_indices` when the first argument is a SparseTensor.

     A string-valued sparse `obs_code` with dense shape [3, 4, 1] is passed
     together with a sparse `obs_harm_code` with dense shape [3, 2, 3]. The
     evaluated result must match the expected values, indices and dense
     shape ([3, 2, 1]) listed below.
     """
     code_indices = [
         [0, 0, 0], [0, 1, 0], [0, 2, 0],
         [1, 1, 0], [1, 2, 0], [1, 3, 0],
         [2, 0, 0], [2, 2, 0],
     ]
     code_values = [
         'loinc:1', 'loinc:1', 'loinc:2',
         'loinc:4', 'loinc:4', 'loinc:2',
         'loinc:1', 'loinc:4',
     ]
     obs_code = tf.SparseTensor(
         indices=code_indices, values=code_values, dense_shape=[3, 4, 1])

     harm_indices = [[0, 0, 0], [0, 1, 0], [1, 1, 0], [2, 1, 0], [2, 2, 0]]
     harm_values = [
         'pulse', 'pulse', 'blood_pressure', 'temperature', 'temperature',
     ]
     obs_harm_code = tf.SparseTensor(
         indices=harm_indices, values=harm_values, dense_shape=[3, 2, 3])

     # Evaluated string tensors come back as bytes, hence the b'' literals.
     expected_values = [b'loinc:1', b'loinc:1', b'loinc:4', b'loinc:4']
     expected_indices = [[0, 0, 0], [0, 1, 0], [1, 0, 0], [2, 1, 0]]
     expected_shape = [3, 2, 1]

     intersected = input_fn._intersect_indices(obs_code, obs_harm_code)
     with self.test_session() as sess:
         actual = sess.run(intersected)
         self.assertAllEqual(expected_values, actual.values)
         self.assertAllEqual(expected_indices, actual.indices)
         self.assertAllEqual(expected_shape, actual.dense_shape)
 def test_intersect_indices_dense(self):
     """Checks `_intersect_indices` when the first argument is a dense tensor.

     A dense 3x5 integer `delta_time` is passed together with a sparse
     `obs_harm_code` with dense shape [3, 2, 3]. The evaluated result must
     be close to the expected nested list of shape [3, 2, 1] given below.
     """
     delta_time = tf.constant([
         [1, 2, 3, 4, 5],
         [10, 20, 30, 40, 0],
         [100, 200, 300, 0, 0],
     ])

     harm_indices = [[0, 0, 0], [0, 1, 0], [1, 1, 2], [2, 1, 0]]
     harm_values = ['pulse', 'pulse', 'blood_pressure', 'temperature']
     obs_harm_code = tf.SparseTensor(
         indices=harm_indices, values=harm_values, dense_shape=[3, 2, 3])

     expected = [
         [[1], [2]],
         [[20], [0]],
         [[200], [0]],
     ]

     intersected = input_fn._intersect_indices(delta_time, obs_harm_code)
     with self.test_session() as sess:
         actual = sess.run(intersected)
         self.assertAllClose(expected, actual)