Example #1
    def test_get_sparse_tensors_dynamic_zero_length(self):
        """Tests _get_sparse_tensors with a dynamic sequence length."""
        with ops.Graph().as_default():
            inputs = sparse_tensor.SparseTensorValue(
                indices=np.zeros((0, 2)),
                values=[],
                dense_shape=(2, 0))
            expected = sparse_tensor.SparseTensorValue(
                indices=np.zeros((0, 3)),
                values=np.array((), dtype=np.int64),
                dense_shape=(2, 0, 1))
            column = sfc.sequence_categorical_column_with_vocabulary_file(
                key='aaa',
                vocabulary_file=self._wire_vocabulary_file_name,
                vocabulary_size=self._wire_vocabulary_size)
            input_placeholder_shape = list(inputs.dense_shape)
            # Make second dimension (sequence length) dynamic.
            input_placeholder_shape[1] = None
            input_placeholder = array_ops.sparse_placeholder(
                dtypes.string, shape=input_placeholder_shape)
            id_weight_pair = _get_sparse_tensors(column,
                                                 {'aaa': input_placeholder})

            self.assertIsNone(id_weight_pair.weight_tensor)
            with _initialized_session() as sess:
                result = id_weight_pair.id_tensor.eval(
                    session=sess, feed_dict={input_placeholder: inputs})
                _assert_sparse_tensor_value(self, expected, result)
Example #2
    def test_get_sparse_tensors(self, inputs_args, expected_args):
        inputs = sparse_tensor.SparseTensorValue(**inputs_args)
        expected = sparse_tensor.SparseTensorValue(**expected_args)
        column = sfc.sequence_categorical_column_with_vocabulary_file(
            key='aaa',
            vocabulary_file=self._wire_vocabulary_file_name,
            vocabulary_size=self._wire_vocabulary_size)

        id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs})

        self.assertIsNone(id_weight_pair.weight_tensor)
        self.evaluate(variables_lib.global_variables_initializer())
        self.evaluate(lookup_ops.tables_initializer())
        _assert_sparse_tensor_value(self, expected,
                                    self.evaluate(id_weight_pair.id_tensor))
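The test above is parameterized: inputs_args and expected_args are dicts of keyword arguments that get unpacked into SparseTensorValue. A minimal sketch of what one parameter set could look like, assuming absl-style parameterized tests and that out-of-vocabulary strings map to the default id -1; the concrete ids depend on the contents of self._wire_vocabulary_file_name, so the values below are illustrative only:

    # Hypothetical parameter set, as it might appear in a
    # @parameterized.named_parameters(...) decorator on the test above.
    example_case = {
        'testcase_name': '2d_input',
        'inputs_args': {
            'indices': ((0, 0), (1, 0), (1, 1)),
            'values': ('marlo', 'skywalker', 'omar'),
            'dense_shape': (2, 2)},
        'expected_args': {
            'indices': ((0, 0, 0), (1, 0, 0), (1, 1, 0)),
            # Ids depend on the vocabulary file ordering; out-of-vocabulary
            # values map to -1 by default.
            'values': np.array((2, -1, 0), dtype=np.int64),
            'dense_shape': (2, 2, 1)},
    }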
Example #3
def _get_sequence_categorical_column(params: dict) -> fc.SequenceCategoricalColumn:
    """Builds a sequence categorical column from a feature-spec dict."""
    key = params['key']
    if 'vocabulary' in params:
        feature = sfc.sequence_categorical_column_with_vocabulary_list(
            key,
            vocabulary_list=_parse_vocabulary(params['vocabulary']),
            default_value=0)
    elif 'bucket_size' in params:
        feature = sfc.sequence_categorical_column_with_hash_bucket(
            key, hash_bucket_size=params['bucket_size'])
    elif 'file' in params:
        feature = sfc.sequence_categorical_column_with_vocabulary_file(
            key, vocabulary_file=params['file'], default_value=0)
    elif 'num_buckets' in params:
        feature = sfc.sequence_categorical_column_with_identity(
            key, num_buckets=params['num_buckets'])
    else:
        raise ValueError(
            "params must contain one of 'vocabulary', 'bucket_size', 'file' "
            "or 'num_buckets'; got keys: {}".format(list(params.keys())))

    return feature
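A usage sketch of the helper above; the params dicts and feature keys are illustrative, not taken from the source:

# Hypothetical feature specs: each dict carries a 'key' plus exactly one of
# the options the helper recognizes ('vocabulary', 'bucket_size', 'file',
# 'num_buckets').
weekday_params = {'key': 'weekday', 'num_buckets': 7}
page_params = {'key': 'page_id', 'bucket_size': 1000}

weekday_column = _get_sequence_categorical_column(weekday_params)
page_column = _get_sequence_categorical_column(page_params)

# The resulting sequence categorical columns are typically wrapped in
# embedding or indicator columns before being fed to a sequence input layer.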