예제 #1
0
 def testPerBatchBudgetTrimmer(self,
                               max_seq_length,
                               segments,
                               expected,
                               axis=-1,
                               descr=None):
   """Checks that `WaterfallTrimmer.trim` yields the expected segments.

   Args:
     max_seq_length: Budget value; wrapped in a constant tensor and passed to
       `WaterfallTrimmer`.
     segments: Nested Python lists; each entry is converted with
       `ragged_factory_ops.constant` and the list is passed to `trim`.
     expected: Nested Python lists, one per expected output segment,
       converted the same way and compared against the trimmed result.
     axis: Axis forwarded to `WaterfallTrimmer`.
     descr: Not used in the body; presumably a label for the parameterized
       test case -- confirm against the test decorator.
   """
   max_seq_length = constant_op.constant(max_seq_length)
   trimmer = trimmer_ops.WaterfallTrimmer(max_seq_length, axis=axis)
   segments = [ragged_factory_ops.constant(seg) for seg in segments]
   expected = [ragged_factory_ops.constant(exp) for exp in expected]
   actual = trimmer.trim(segments)
   # zip() silently stops at the shorter input, so without this check a
   # trimmer that dropped (or added) whole segments would still pass.
   self.assertEqual(len(expected), len(actual))
   for expected_seg, actual_seg in zip(expected, actual):
     self.assertAllEqual(expected_seg, actual_seg)
예제 #2
0
 def testGenerateMask(self,
                      segments,
                      max_seq_length,
                      expected,
                      axis=-1,
                      descr=None):
   """Checks that `WaterfallTrimmer.generate_mask` yields the expected masks.

   Args:
     segments: Nested Python lists; each entry is converted with
       `ragged_factory_ops.constant` and the list is passed to
       `generate_mask`.
     max_seq_length: Budget value; wrapped in a constant tensor and passed to
       `WaterfallTrimmer`.
     expected: Nested Python lists, one per expected mask, converted the same
       way and compared against the generated masks.
     axis: Axis forwarded to `WaterfallTrimmer`.
     descr: Not used in the body; presumably a label for the parameterized
       test case -- confirm against the test decorator.
   """
   max_seq_length = constant_op.constant(max_seq_length)
   segments = [ragged_factory_ops.constant(i) for i in segments]
   expected = [ragged_factory_ops.constant(i) for i in expected]
   trimmer = trimmer_ops.WaterfallTrimmer(max_seq_length, axis=axis)
   actual = trimmer.generate_mask(segments)
   # zip() silently stops at the shorter input, so without this check a
   # wrong number of masks would go unnoticed.
   self.assertEqual(len(expected), len(actual))
   for expected_mask, actual_mask in zip(expected, actual):
     # (expected, actual) order matches testPerBatchBudgetTrimmer.
     self.assertAllEqual(expected_mask, actual_mask)