def test_pad_sequences_vector(self):
    """Check `pad_sequences` on ragged sequences whose items are 2-vectors.

    Each case pads/truncates the input to a fixed length and compares it with
    the expected dense array; filler positions become `[value, value]`.
    """
    sequences = [[[1, 1]], [[2, 1], [2, 2]], [[3, 1], [3, 2], [3, 3]]]
    cases = [
        # Padding side: filler vectors go before ('pre') or after ('post').
        (dict(maxlen=3, padding='pre'),
         [[[0, 0], [0, 0], [1, 1]],
          [[0, 0], [2, 1], [2, 2]],
          [[3, 1], [3, 2], [3, 3]]]),
        (dict(maxlen=3, padding='post'),
         [[[1, 1], [0, 0], [0, 0]],
          [[2, 1], [2, 2], [0, 0]],
          [[3, 1], [3, 2], [3, 3]]]),
        # Truncation side: the longest row loses its first ('pre') or last
        # ('post') vector to fit maxlen=2.
        (dict(maxlen=2, truncating='pre'),
         [[[0, 0], [1, 1]],
          [[2, 1], [2, 2]],
          [[3, 2], [3, 3]]]),
        (dict(maxlen=2, truncating='post'),
         [[[0, 0], [1, 1]],
          [[2, 1], [2, 2]],
          [[3, 1], [3, 2]]]),
        # Custom fill value.
        (dict(maxlen=3, value=1),
         [[[1, 1], [1, 1], [1, 1]],
          [[1, 1], [2, 1], [2, 2]],
          [[3, 1], [3, 2], [3, 3]]]),
    ]
    for kwargs, expected in cases:
        padded = data_utils.pad_sequences(sequences, **kwargs)
        self.assertAllClose(padded, expected)
def get_data(
    self,
    count=(_GLOBAL_BATCH_SIZE * _EVAL_STEPS),
    min_words=5,
    max_words=10,
    max_word_id=19,
    num_classes=2,
):
    """Generate a synthetic, padded word-id classification dataset.

    Each class gets its own random probability distribution over the word
    ids `0 .. max_word_id - 1`. Every example draws a random label, a random
    length in `[min_words, max_words)`, and that many word ids (with
    replacement) from its class's distribution. All sequences are then padded
    to `max_words` with `data_utils.pad_sequences`.

    Args:
      count: Number of examples to generate.
      min_words: Inclusive lower bound on an example's length.
      max_words: Exclusive upper bound on the sampled length; also the
        padded length of every row.
      max_word_id: Vocabulary size; word ids are drawn from
        `range(max_word_id)`.
      num_classes: Number of label classes.

    Returns:
      A tuple `(x_train, y_train, x_predict)`: `x_train` is a float32 array of
      padded word ids, `y_train` an int32 array of shape `(count, 1)`, and
      `x_predict` the first `_GLOBAL_BATCH_SIZE` rows of `x_train`.
    """
    # One normalized (sums to 1) distribution over word ids per class.
    distribution = []
    for _ in range(num_classes):
        dist = np.abs(np.random.randn(max_word_id))
        dist /= np.sum(dist)
        distribution.append(dist)

    features = []
    labels = []
    for _ in range(count):
        label = np.random.randint(0, num_classes, size=1)[0]
        # randint's upper bound is exclusive, so lengths are < max_words.
        num_words = np.random.randint(min_words, max_words, size=1)[0]
        # NOTE(review): word id 0 can be sampled here and is indistinguishable
        # from the pad_sequences default fill value 0 — confirm this is intended.
        word_ids = np.random.choice(
            max_word_id, size=num_words, replace=True, p=distribution[label]
        )
        labels.append(label)
        features.append(word_ids)

    features = data_utils.pad_sequences(features, maxlen=max_words)
    x_train = np.asarray(features, dtype=np.float32)
    y_train = np.asarray(labels, dtype=np.int32).reshape((count, 1))
    x_predict = x_train[:_GLOBAL_BATCH_SIZE]
    return x_train, y_train, x_predict
def test_pad_sequences(self):
    """Check `pad_sequences` padding, truncation and fill value on int lists."""
    sequences = [[1], [1, 2], [1, 2, 3]]

    def check(expected, **kwargs):
        # Pad the shared input with the given options and compare numerically.
        self.assertAllClose(
            data_utils.pad_sequences(sequences, **kwargs), expected
        )

    # Padding side: zeros go before ("pre") or after ("post") the data.
    check([[0, 0, 1], [0, 1, 2], [1, 2, 3]], maxlen=3, padding="pre")
    check([[1, 0, 0], [1, 2, 0], [1, 2, 3]], maxlen=3, padding="post")

    # Truncation side: the longest row drops its first ("pre") or last
    # ("post") element to fit maxlen=2.
    check([[0, 1], [1, 2], [2, 3]], maxlen=2, truncating="pre")
    check([[0, 1], [1, 2], [1, 2]], maxlen=2, truncating="post")

    # Custom fill value.
    check([[1, 1, 1], [1, 1, 2], [1, 2, 3]], maxlen=3, value=1)
def test_pad_sequences_str(self):
    """Check `pad_sequences` with string items and a string fill value."""
    # NOTE(review): another `test_pad_sequences_str` is defined later in this
    # listing; if both are in the same class, the later one silently replaces
    # this one — confirm and rename or drop the duplicate.
    sequences = [["1"], ["1", "2"], ["1", "2", "3"]]
    cases = [
        (
            dict(maxlen=3, padding="pre", dtype=object),
            [["pad", "pad", "1"], ["pad", "1", "2"], ["1", "2", "3"]],
        ),
        (
            dict(maxlen=3, padding="post", dtype="<U3"),
            [["1", "pad", "pad"], ["1", "2", "pad"], ["1", "2", "3"]],
        ),
        (
            dict(maxlen=2, truncating="pre", dtype=object),
            [["pad", "1"], ["1", "2"], ["2", "3"]],
        ),
        (
            dict(maxlen=2, truncating="post", dtype="<U3"),
            [["pad", "1"], ["1", "2"], ["1", "2"]],
        ),
    ]
    for kwargs, expected in cases:
        padded = data_utils.pad_sequences(sequences, value="pad", **kwargs)
        self.assertAllEqual(padded, expected)

    # With no dtype given, the string fill value must be rejected (the expected
    # message names int32 as the incompatible dtype).
    with self.assertRaisesRegex(
        ValueError, "`dtype` int32 is not compatible with "
    ):
        data_utils.pad_sequences(
            sequences, maxlen=2, truncating="post", value="pad"
        )
def test_pad_sequences_str(self):
    """Check `pad_sequences` with string items and a string fill value."""
    sequences = [['1'], ['1', '2'], ['1', '2', '3']]

    def check(expected, **kwargs):
        # Every case pads with the string 'pad'; compare element-wise.
        padded = data_utils.pad_sequences(sequences, value='pad', **kwargs)
        self.assertAllEqual(padded, expected)

    # Padding side, with both object and fixed-width unicode dtypes.
    check([['pad', 'pad', '1'], ['pad', '1', '2'], ['1', '2', '3']],
          maxlen=3, padding='pre', dtype=object)
    check([['1', 'pad', 'pad'], ['1', '2', 'pad'], ['1', '2', '3']],
          maxlen=3, padding='post', dtype='<U3')

    # Truncation side.
    check([['pad', '1'], ['1', '2'], ['2', '3']],
          maxlen=2, truncating='pre', dtype=object)
    check([['pad', '1'], ['1', '2'], ['1', '2']],
          maxlen=2, truncating='post', dtype='<U3')

    # Without an explicit dtype, a string fill value is rejected.
    with self.assertRaisesRegex(ValueError,
                                '`dtype` int32 is not compatible with '):
        data_utils.pad_sequences(sequences, maxlen=2, truncating='post',
                                 value='pad')