def test_not_equal_indices_and_seg_ids_num(self):
    """Reject `indices` and `segment_ids` whose lengths differ.

    `indices` has 6 entries (0..5, all valid rows of the 10-row `data`)
    while `segment_ids` has 7 entries, so the length mismatch is the only
    problem with the inputs. Building and evaluating the op both happen
    inside `assertRaises`, and either `ValueError` or
    `errors.InvalidArgumentError` is accepted.
    """
    with self.session(use_gpu=use_gpu, config=default_config):
        flat_values = constant_op.constant(list(range(20)),
                                           dtype=dtypes.float32)
        data = array_ops.reshape(flat_values, (10, 2))
        indices = constant_op.constant(list(range(6)), dtype=dtypes.int32)
        segment_ids = constant_op.constant(list(range(7)),
                                           dtype=dtypes.int32)
        accepted_errors = (ValueError, errors.InvalidArgumentError)
        with self.assertRaises(accepted_errors):
            self.evaluate(
                de_math.sparse_segment_sum(data, indices, segment_ids))
 def forward_compute(self, data, indices, segment_ids, num_segments=None):
     """Run the forward pass of both sparse_segment_sum implementations.

     All arguments are forwarded unchanged to `de_math.sparse_segment_sum`
     and then to `math_ops.sparse_segment_sum`.

     Returns:
       A `(result, expected)` tuple: the `de_math` output first, followed by
       the `math_ops` output it is meant to be compared against.
     """
     positional_args = (data, indices, segment_ids)
     result = de_math.sparse_segment_sum(*positional_args,
                                         num_segments=num_segments)
     expected = math_ops.sparse_segment_sum(*positional_args,
                                            num_segments=num_segments)
     return result, expected
# Example #3
 def backward_compute(self, data, indices, segment_ids, num_segments=None):
     """Differentiate both sparse_segment_sum implementations w.r.t. `data`.

     Both ops are recorded on one persistent `GradientTape`. The tape must
     be persistent because `tape.gradient` is called twice, once per
     implementation.

     All arguments are forwarded unchanged to `de_math.sparse_segment_sum`
     and `math_ops.sparse_segment_sum`.

     Returns:
       A `(result, expected)` tuple holding the gradient of the `de_math`
       output and of the `math_ops` output, each taken with respect to
       `data`.
     """
     with backprop.GradientTape(persistent=True) as tape:
         # Watch explicitly: `data` may be a plain tensor, which the tape
         # would not track on its own.
         tape.watch(data)
         de_output = de_math.sparse_segment_sum(data,
                                                indices,
                                                segment_ids,
                                                num_segments=num_segments)
         reference_output = math_ops.sparse_segment_sum(
             data, indices, segment_ids, num_segments=num_segments)
     de_grad = tape.gradient(de_output, data)
     reference_grad = tape.gradient(reference_output, data)
     return de_grad, reference_grad