Beispiel #1
0
    def test_multiple_segment_join(self):
        """Join twice, feeding the first result back in as the second input."""
        first_expected = [['Y', '6', '6'], ['Y:p', 'q:G', 'c:a']]
        first = self.evaluate(
            string_ops.unsorted_segment_join(
                inputs=[['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']],
                segment_ids=[1, 0, 1],
                num_segments=2,
                separator=':'))
        self.assertAllEqualUnicode(first, first_expected)
        self.assertAllEqual(first.shape, np.array(first_expected).shape)

        # Both rows of the first result go to segment 1, so segment 0 is all
        # empty strings, and the empty separator concatenates rows directly.
        second_expected = [['', '', ''], ['YY:p', '6q:G', '6c:a']]
        second = self.evaluate(
            string_ops.unsorted_segment_join(
                inputs=first,
                segment_ids=[1, 1],
                num_segments=2,
                separator=''))
        self.assertAllEqualUnicode(second, second_expected)
        self.assertAllEqual(second.shape, np.array(second_expected).shape)
Beispiel #2
0
 def test_multiple_cases_with_different_dims(self, inputs, segment_ids,
                                             num_segments, separator,
                                             output_array):
     """Check the values and shape of one parameterized segment join.

     Args:
       inputs: strings passed as the op's ``inputs``.
       segment_ids: segment id for each row of ``inputs``.
       num_segments: number of output segments.
       separator: string placed between joined elements.
       output_array: expected (unicode) result of the join.
     """
     op_kwargs = dict(inputs=inputs,
                      segment_ids=segment_ids,
                      num_segments=num_segments,
                      separator=separator)
     joined = self.evaluate(string_ops.unsorted_segment_join(**op_kwargs))
     self.assertAllEqualUnicode(joined, output_array)
     expected_shape = np.array(output_array).shape
     self.assertAllEqual(joined.shape, expected_shape)
Beispiel #3
0
 def test_fail_negative_segment_id(self):
     """A negative segment id must raise InvalidArgumentError."""
     data = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
     ids = [-1, 0, -1]
     with self.assertRaises(errors_impl.InvalidArgumentError):
         # Both building and evaluating the op stay inside the context
         # manager, so the error is caught whichever step raises it.
         joined = string_ops.unsorted_segment_join(
             inputs=data, segment_ids=ids, num_segments=1, separator=':')
         self.evaluate(joined)
Beispiel #4
0
 def test_empty_input(self):
     """Empty inputs paired with non-empty segment ids must be rejected.

     Either ``ValueError`` or ``InvalidArgumentError`` is accepted.
     """
     # np.bytes_ instead of np.string_: the np.string_ alias was removed in
     # NumPy 2.0, while np.bytes_ is the same byte-string dtype everywhere.
     inputs = np.array([], dtype=np.bytes_)
     segment_ids = [1, 0, 1]
     num_segments = 2
     separator = ':'
     with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
         self.evaluate(
             string_ops.unsorted_segment_join(inputs=inputs,
                                              segment_ids=segment_ids,
                                              num_segments=num_segments,
                                              separator=separator))
Beispiel #5
0
 def test_fail_segment_id_empty_input_non_empty(self):
     """Empty segment ids paired with non-empty inputs must be rejected."""
     data = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
     no_ids = np.array([], dtype=np.int32)
     accepted_errors = (ValueError, errors_impl.InvalidArgumentError)
     with self.assertRaises(accepted_errors):
         pending = string_ops.unsorted_segment_join(
             inputs=data, segment_ids=no_ids, num_segments=2, separator=':')
         self.evaluate(pending)
Beispiel #6
0
    def test_fail_segment_id_dim_does_not_match(self):
        """More segment ids than input rows must be rejected.

        The expected exception depends on the execution mode. In graph mode
        it is ``ValueError``, presumably from shape checking while the op is
        built (TODO confirm). In eager mode it is ``InvalidArgumentError``.
        """
        inputs = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
        segment_ids = [1, 0, 1, 1]
        num_segments = 2
        separator = ':'

        # Choose the expected error up front so the op call appears only once
        # instead of being duplicated in both branches.
        if context.executing_eagerly():
            expected_error = errors_impl.InvalidArgumentError
        else:
            expected_error = ValueError

        with self.assertRaises(expected_error):
            self.evaluate(
                string_ops.unsorted_segment_join(inputs=inputs,
                                                 segment_ids=segment_ids,
                                                 num_segments=num_segments,
                                                 separator=separator))
Beispiel #7
0
 def testSeparator(self, separator, output_array):
     """Join four words into a single segment using ``separator``.

     Args:
       separator: string placed between the joined words.
       output_array: expected (unicode) result of the join.
     """
     words = ['this', 'is', 'a', 'test']
     joined = self.evaluate(
         string_ops.unsorted_segment_join(
             inputs=words,
             segment_ids=[0] * len(words),  # every word goes to segment 0
             num_segments=1,
             separator=separator))
     self.assertAllEqual(joined.shape, np.array(output_array).shape)
     self.assertAllEqualUnicode(joined, output_array)
Beispiel #8
0
 def test_segment_id_and_input_empty(self):
     """Empty inputs and empty segment ids yield ``num_segments`` empty strings."""
     # np.bytes_ instead of np.string_: the np.string_ alias was removed in
     # NumPy 2.0, while np.bytes_ is the same byte-string dtype everywhere.
     inputs = np.array([], dtype=np.bytes_)
     segment_ids = np.array([], dtype=np.int32)
     num_segments = 3
     separator = ':'
     output_array = ['', '', '']
     res = self.evaluate(
         string_ops.unsorted_segment_join(inputs=inputs,
                                          segment_ids=segment_ids,
                                          num_segments=num_segments,
                                          separator=separator))
     self.assertAllEqual(res.shape, np.array(output_array).shape)
     self.assertAllEqualUnicode(res, output_array)
Beispiel #9
0
    def test_basic_np_array(self):
        """Segment join accepts NumPy arrays for inputs and segment ids.

        The original version passed plain Python lists here, which only
        repeated the list-input coverage of the other tests. This version
        passes real ``np.ndarray`` values, as the test name says.
        """
        # Byte-string dtype, matching the other NumPy string inputs in this
        # file; an explicit int32 avoids a platform-dependent default int.
        inputs = np.array([['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']],
                          dtype=np.bytes_)
        segment_ids = np.array([1, 0, 1], dtype=np.int32)
        num_segments = 2
        separator = ':'
        output_array = [['Y', '6', '6'], ['Y:p', 'q:G', 'c:a']]

        res = self.evaluate(
            string_ops.unsorted_segment_join(inputs=inputs,
                                             segment_ids=segment_ids,
                                             num_segments=num_segments,
                                             separator=separator))
        self.assertAllEqual(res.shape, np.array(output_array).shape)
        self.assertAllEqualUnicode(res, output_array)
Beispiel #10
0
    def test_type_check(self):
        """Segment ids and num_segments may be int32 or int64 NumPy values."""
        data = [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']]
        expected = [['Y', '6', '6'], ['Y:p', 'q:G', 'c:a']]
        # The same join must succeed for each supported index dtype.
        for index_dtype in (np.int32, np.int64):
            joined = self.evaluate(
                string_ops.unsorted_segment_join(
                    inputs=data,
                    segment_ids=np.array([1, 0, 1], dtype=index_dtype),
                    num_segments=np.array(2, dtype=index_dtype),
                    separator=':'))
            self.assertAllEqual(joined.shape, np.array(expected).shape)
            self.assertAllEqualUnicode(joined, expected)
Beispiel #11
0
    def test_basic_tensor(self):
        """Segment join accepts tensors and matches an expected tensor."""
        data = constant_op.constant(
            [['Y', 'q', 'c'], ['Y', '6', '6'], ['p', 'G', 'a']])
        ids = constant_op.constant([1, 0, 1])
        expected = constant_op.constant(
            [['Y', '6', '6'], ['Y:p', 'q:G', 'c:a']])

        joined_op = string_ops.unsorted_segment_join(
            inputs=data, segment_ids=ids, num_segments=2, separator=':')
        joined = self.evaluate(joined_op)
        self.assertAllEqual(joined, expected)
        self.assertAllEqual(joined.shape, expected.get_shape())