Пример #1
0
    def call(self, inputs):
        """Map every value of `inputs` to its bucket index under `self.bins`.

        Values are cast to float32 and bucketized with `BoostedTreesBucketize`,
        using `squeeze(self.bins)` (cast to float32) as the bucket boundaries.
        Ragged, `SparseTensor` and dense inputs are handled separately.

        Args:
          inputs: A ragged tensor, a `SparseTensor`, or a dense tensor. Dense
            inputs must have fully known non-batch dimensions.

        Returns:
          The integer bucket indices, mirroring the kind of `inputs`:
          - Ragged inputs keep their row structure (only flat values change).
          - Sparse inputs keep their `indices` and `dense_shape`.
          - Dense inputs are reshaped back to `[-1] + inputs.shape[1:]`.

        Raises:
          NotImplementedError: If a dense input has an unknown non-batch
            dimension.
        """
        def _bucketize_op(bins):
            # Close over float32 boundaries; the returned callable bucketizes a
            # single tensor and returns the op's first (only) output.
            bins = [math_ops.cast(bins, dtypes.float32)]
            return lambda inputs: gen_boosted_trees_ops.BoostedTreesBucketize(  # pylint: disable=g-long-lambda
                float_values=[math_ops.cast(inputs, dtypes.float32)],
                bucket_boundaries=bins)[0]

        # Every input kind uses the same boundaries, so build the bucketizer once.
        bucketize = _bucketize_op(array_ops.squeeze(self.bins))

        if tf_utils.is_ragged(inputs):
            integer_buckets = ragged_functional_ops.map_flat_values(
                bucketize, inputs)
            # Ragged map_flat_values doesn't touch the non-values tensors in the
            # ragged composite tensor. If this op is the only op in a Keras
            # model, this can cause errors in Graph mode, so wrap the tensor in
            # an identity.
            return array_ops.identity(integer_buckets)

        if isinstance(inputs, sparse_tensor.SparseTensor):
            # Only the stored values are bucketized; the sparsity pattern is
            # carried over unchanged.
            integer_buckets = bucketize(inputs.values)
            return sparse_tensor.SparseTensor(
                indices=array_ops.identity(inputs.indices),
                values=integer_buckets,
                dense_shape=array_ops.identity(inputs.dense_shape))

        input_shape = inputs.get_shape()
        non_batch_dims = input_shape.as_list()[1:]
        if any(dim is None for dim in non_batch_dims):
            raise NotImplementedError(
                "Discretization Layer requires known non-batch shape, "
                "found {}".format(input_shape))

        # All non-batch dims are static here, so their product is a plain Python
        # int. (A TF `Prod` over an empty Python list -- rank-1 inputs -- yields
        # a float32 tensor, which reshape rejects as a dimension.)
        flat_size = 1
        for dim in non_batch_dims:
            flat_size *= dim

        # Flatten to [batch, features], bucketize row by row, then restore the
        # original non-batch shape.
        reshaped = array_ops.reshape(inputs, [-1, flat_size])
        return array_ops.reshape(
            control_flow_ops.vectorized_map(bucketize, reshaped),
            array_ops.constant([-1] + non_batch_dims))
Пример #2
0
 def _bucketize_op(bins):
     """Build a callable that bucketizes one tensor against fixed boundaries.

     Args:
       bins: Bucket boundaries. They are cast to float32 once, when this
         factory is called, and reused by every call of the returned callable.

     Returns:
       A one-argument callable. It casts its argument to float32, runs
       `BoostedTreesBucketize` with the captured boundaries, and returns the
       op's first output.
     """
     boundaries = [math_ops.cast(bins, dtypes.float32)]

     def _apply(inputs):
         float_inputs = math_ops.cast(inputs, dtypes.float32)
         outputs = gen_boosted_trees_ops.BoostedTreesBucketize(
             float_values=[float_inputs], bucket_boundaries=boundaries)
         return outputs[0]

     return _apply
Пример #3
0
 def _bucketize_fn(inputs):
     """Cast `inputs` to float32 and return its `BoostedTreesBucketize` result.

     NOTE(review): `bins` is a free variable resolved from a scope not visible
     in this snippet. It is passed straight through as `bucket_boundaries`, so
     it is presumably already a list of float32 boundary tensors; confirm this
     at the definition site.
     """
     float_inputs = math_ops.cast(inputs, dtypes.float32)
     bucketized = gen_boosted_trees_ops.BoostedTreesBucketize(
         float_values=[float_inputs], bucket_boundaries=bins)
     return bucketized[0]