Example 1
# These `__init__` methods are excerpted from class bodies; the snippets rely
# on the following imports from the TensorFlow Probability source tree:
#   import tensorflow.compat.v2 as tf
#   from tensorflow_probability.python.internal import assert_util
#   from tensorflow_probability.python.internal import distribution_util
#   from tensorflow_probability.python.internal import dtype_util

    def __init__(self,
                 learning_rate,
                 preconditioner_decay_rate=0.95,
                 data_size=1,
                 burnin=25,
                 diagonal_bias=1e-8,
                 name=None,
                 parallel_iterations=10):
        default_name = 'StochasticGradientLangevinDynamics'
        with tf.name_scope(name or default_name):
            if tf.executing_eagerly():
                raise NotImplementedError(
                    'Eager execution currently not supported for '
                    'SGLD optimizer.')

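            # Coerce each hyperparameter to a Tensor inside the name scope so
            # the assertions below and downstream ops see consistent dtypes.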
            self._preconditioner_decay_rate = tf.convert_to_tensor(
                preconditioner_decay_rate, name='preconditioner_decay_rate')
            self._data_size = tf.convert_to_tensor(data_size, name='data_size')
            self._burnin = tf.convert_to_tensor(burnin,
                                                name='burnin',
                                                dtype=dtype_util.common_dtype(
                                                    [burnin],
                                                    dtype_hint=tf.int64))
            self._diagonal_bias = tf.convert_to_tensor(diagonal_bias,
                                                       name='diagonal_bias')
            # TODO(b/124800185): Consider migrating `learning_rate` to be a
            # hyperparameter handled by the base Optimizer class. This would allow
            # users to plug in a `tf.keras.optimizers.schedules.LearningRateSchedule`
            # object in addition to Tensors.
            self._learning_rate = tf.convert_to_tensor(learning_rate,
                                                       name='learning_rate')
            self._parallel_iterations = parallel_iterations

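            # `with_dependencies` returns the same tensor gated on the given
            # assertion ops, so out-of-range hyperparameters fail at run time.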
            self._preconditioner_decay_rate = distribution_util.with_dependencies(
                [
                    assert_util.assert_non_negative(
                        self._preconditioner_decay_rate,
                        message=
                        '`preconditioner_decay_rate` must be non-negative'),
                    assert_util.assert_less_equal(
                        self._preconditioner_decay_rate,
                        1.,
                        message='`preconditioner_decay_rate` must be at most 1.'
                    ),
                ], self._preconditioner_decay_rate)

            self._data_size = distribution_util.with_dependencies([
                assert_util.assert_greater(
                    self._data_size,
                    0,
                    message='`data_size` must be greater than zero')
            ], self._data_size)

            self._burnin = distribution_util.with_dependencies([
                assert_util.assert_non_negative(
                    self._burnin, message='`burnin` must be non-negative'),
                assert_util.assert_integer(
                    self._burnin, message='`burnin` must be an integer')
            ], self._burnin)

            self._diagonal_bias = distribution_util.with_dependencies([
                assert_util.assert_non_negative(
                    self._diagonal_bias,
                    message='`diagonal_bias` must be non-negative')
            ], self._diagonal_bias)

            super(StochasticGradientLangevinDynamics,
                  self).__init__(name=name or default_name)
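
The constructor above only validates and stores the hyperparameters. A minimal
usage sketch follows; it assumes the class is exported as
`tfp.optimizer.StochasticGradientLangevinDynamics` and runs in graph mode,
since the constructor rejects eager execution. The dataset size is a
hypothetical value.

import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp

tf.disable_eager_execution()  # the constructor raises NotImplementedError in eager mode

opt = tfp.optimizer.StochasticGradientLangevinDynamics(
    learning_rate=0.01,
    preconditioner_decay_rate=0.95,
    data_size=60000,  # effective number of points in the dataset (hypothetical)
    burnin=25)        # iterations spent warming up the preconditioner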

Example 2
    def __init__(self,
                 batch_size,
                 total_num_examples,
                 max_learning_rate=1.,
                 preconditioner_decay_rate=0.95,
                 burnin=25,
                 burnin_max_learning_rate=1e-6,
                 use_single_learning_rate=False,
                 name=None):
        default_name = 'VariationalSGD'
        with tf.name_scope(name or default_name):
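            # As in Example 1, coerce the hyperparameters to Tensors up front.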
            self._preconditioner_decay_rate = tf.convert_to_tensor(
                preconditioner_decay_rate, name='preconditioner_decay_rate')
            self._batch_size = tf.convert_to_tensor(batch_size,
                                                    name='batch_size')
            self._total_num_examples = tf.convert_to_tensor(
                total_num_examples, name='total_num_examples')

            self._burnin = tf.convert_to_tensor(burnin,
                                                name='burnin',
                                                dtype=dtype_util.common_dtype(
                                                    [burnin],
                                                    dtype_hint=tf.int64))
            self._burnin_max_learning_rate = tf.convert_to_tensor(
                burnin_max_learning_rate, name='burnin_max_learning_rate')
            self._max_learning_rate = tf.convert_to_tensor(
                max_learning_rate, name='max_learning_rate')
            self._use_single_learning_rate = use_single_learning_rate

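            # Gate each hyperparameter on its range assertions, as in Example 1.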
            self._preconditioner_decay_rate = distribution_util.with_dependencies(
                [
                    assert_util.assert_non_negative(
                        self._preconditioner_decay_rate,
                        message=
                        '`preconditioner_decay_rate` must be non-negative'),
                    assert_util.assert_less_equal(
                        self._preconditioner_decay_rate,
                        1.,
                        message='`preconditioner_decay_rate` must be at most 1.'
                    ),
                ], self._preconditioner_decay_rate)

            self._batch_size = distribution_util.with_dependencies([
                assert_util.assert_greater(
                    self._batch_size,
                    0,
                    message='`batch_size` must be greater than zero')
            ], self._batch_size)

            self._total_num_examples = distribution_util.with_dependencies([
                assert_util.assert_greater(
                    self._total_num_examples,
                    0,
                    message='`total_num_examples` must be greater than zero')
            ], self._total_num_examples)

            self._burnin = distribution_util.with_dependencies([
                assert_util.assert_non_negative(
                    self._burnin, message='`burnin` must be non-negative'),
                assert_util.assert_integer(
                    self._burnin, message='`burnin` must be an integer')
            ], self._burnin)

            self._burnin_max_learning_rate = distribution_util.with_dependencies(
                [
                    assert_util.assert_non_negative(
                        self._burnin_max_learning_rate,
                        message=
                        '`burnin_max_learning_rate` must be non-negative')
                ], self._burnin_max_learning_rate)

            self._max_learning_rate = distribution_util.with_dependencies([
                assert_util.assert_non_negative(
                    self._max_learning_rate,
                    message='`max_learning_rate` must be non-negative')
            ], self._max_learning_rate)

            super(VariationalSGD, self).__init__(name=name or default_name)
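
A corresponding usage sketch, assuming the class is exported as
`tfp.optimizer.VariationalSGD`; the batch and dataset sizes are hypothetical:

import tensorflow_probability as tfp

opt = tfp.optimizer.VariationalSGD(
    batch_size=128,            # minibatch size drawn per step (hypothetical)
    total_num_examples=60000,  # size of the full training set (hypothetical)
    max_learning_rate=1.0,
    preconditioner_decay_rate=0.95,
    burnin=25,
    burnin_max_learning_rate=1e-6)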