Example No. 1
0
 def test_mask_by_partial_sequence_length_value_error(self):
   """Checks that omitting partial lengths while setting target_length fails.

   Calling the masking op with partial_sequence_lengths=None and a non-None
   target_length must raise a ValueError whose message matches the text
   below.
   """
   expected_message = (
       'partial_sequence_lengths is expected when target_length is not None')
   with self.assertRaisesRegex(ValueError, expected_message):
     metrics.mask_by_partial_sequence_length(
         tensors=self.tensors, partial_sequence_lengths=None, target_length=1)
Example No. 2
0
    def test_mask_by_partial_sequence_length_no_change(self):
        """Checks the op output when neither lengths nor target are given.

        With partial_sequence_lengths=None and target_length=None, the
        evaluated result must equal the literal values below. Presumably
        these are the unmodified contents of self.tensors; the fixture is
        defined outside this view, so confirm against setUp.
        """
        expected = ([[1, 2], [3, 4]], [[5, 6], [7, 8]])
        result_ops = metrics.mask_by_partial_sequence_length(
            self.tensors, partial_sequence_lengths=None, target_length=None)

        with self.test_session() as session:
            evaluated = session.run(result_ops)
            self.assertAllEqual(evaluated, expected)
Example No. 3
0
  def test_mask_by_partial_sequence_length(
      self, partial_sequence_lengths, expected_values):
    """Checks masking of self.tensors for the given lengths and target_length=1.

    Args:
      partial_sequence_lengths: Python value wrapped with tf.constant and
        passed to the op as its partial_sequence_lengths argument.
      expected_values: Values the evaluated op output must equal.
    """
    lengths_tensor = tf.constant(partial_sequence_lengths)
    masked = metrics.mask_by_partial_sequence_length(
        self.tensors, partial_sequence_lengths=lengths_tensor, target_length=1)

    with self.test_session() as session:
      self.assertAllEqual(session.run(masked), expected_values)
Example No. 4
0
  def test_mask_by_partial_sequence_length_empty_output_tensors(self):
    """Checks that lengths of 42 with target_length=1 give empty outputs.

    The evaluated result must contain exactly two arrays, each of shape
    (0, 2).
    """
    masked = metrics.mask_by_partial_sequence_length(
        self.tensors,
        partial_sequence_lengths=tf.constant([42, 42]),
        target_length=1)

    with self.test_session() as session:
      values = session.run(masked)
      self.assertLen(values, 2)
      # The length check above guarantees this loop visits exactly two items.
      for value in values:
        self.assertAllEqual(value.shape, (0, 2))