Example #1
  def __init__(
      self,
      units,
      activation=None,
      activity_regularizer=None,
      kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
      kernel_posterior_tensor_fn=lambda d: d.sample(),
      kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
      kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
      bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
      bias_posterior_tensor_fn=lambda d: d.sample(),
      bias_prior_fn=None,
      bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
      **kwargs):
    # pylint: disable=g-doc-args
    """Construct layer.

    Args:
      ${args}
    """
    # pylint: enable=g-doc-args
    super(_DenseVariational, self).__init__(
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.units = units
    self.activation = tf.keras.activations.get(activation)
    self.input_spec = tf.layers.InputSpec(min_ndim=2)
    self.kernel_posterior_fn = kernel_posterior_fn
    self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
    self.kernel_prior_fn = kernel_prior_fn
    self.kernel_divergence_fn = kernel_divergence_fn
    self.bias_posterior_fn = bias_posterior_fn
    self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
    self.bias_prior_fn = bias_prior_fn
    self.bias_divergence_fn = bias_divergence_fn
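As a quick illustration of how these hooks are customized in practice, here is a minimal sketch (written for this listing, not taken from the source above): it builds a mean-field posterior with a custom scale initializer and plugs it into the public tfp.layers.DenseReparameterization layer, which accepts the same kernel_posterior_fn / kernel_divergence_fn arguments. The initializer constants are arbitrary example values.

# Sketch only: a customized posterior hook plugged into a public TFP layer.
# Assumes TF 2.x and TensorFlow Probability are installed.
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.layers import util as tfp_layers_util

tfd = tfp.distributions

# Mean-field normal posterior whose (softplus-transformed) scale starts small.
posterior_fn = tfp_layers_util.default_mean_field_normal_fn(
    untransformed_scale_initializer=tf.initializers.RandomNormal(
        mean=-5.0, stddev=0.1))

layer = tfp.layers.DenseReparameterization(
    units=16,
    activation='relu',
    kernel_posterior_fn=posterior_fn,
    kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p))

out = layer(tf.zeros([8, 4]))  # stochastic forward pass; KL lands in layer.losses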
Example #2
  def __init__(
      self,
      units,
      activation=None,
      activity_regularizer=None,
      kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
      kernel_posterior_tensor_fn=lambda d: d.sample(),
      kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
      kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
      bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(is_singular=True),  # pylint: disable=line-too-long
      bias_posterior_tensor_fn=lambda d: d.sample(),
      bias_prior_fn=None,
      bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
      **kwargs):
    # pylint: disable=g-doc-args
    """Construct layer.

    Args:
      @{args}
    """
    # pylint: enable=g-doc-args
    super(_DenseVariational, self).__init__(
        activity_regularizer=activity_regularizer,
        **kwargs)
    self.units = units
    self.activation = tf.keras.activations.get(activation)
    self.input_spec = tf.layers.InputSpec(min_ndim=2)
    self.kernel_posterior_fn = kernel_posterior_fn
    self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
    self.kernel_prior_fn = kernel_prior_fn
    self.kernel_divergence_fn = kernel_divergence_fn
    self.bias_posterior_fn = bias_posterior_fn
    self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
    self.bias_prior_fn = bias_prior_fn
    self.bias_divergence_fn = bias_divergence_fn
Example #3
    def __init__(
            self,
            filters,
            kernel_size,
            strides=(1, 1),
            padding='valid',
            output_padding=None,
            data_format='channels_last',
            dilation_rate=(1, 1),
            activation=None,
            activity_regularizer=None,
            kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
            kernel_posterior_tensor_fn=lambda d: d.sample(),
            kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
            kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
                is_singular=True),
            bias_posterior_tensor_fn=lambda d: d.sample(),
            bias_prior_fn=None,
            bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            **kwargs):
        super(Conv2DTransposeReparameterization, self).__init__(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=tf.keras.activations.get(activation),
            activity_regularizer=activity_regularizer,
            kernel_posterior_fn=kernel_posterior_fn,
            kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
            kernel_prior_fn=kernel_prior_fn,
            kernel_divergence_fn=kernel_divergence_fn,
            bias_posterior_fn=bias_posterior_fn,
            bias_posterior_tensor_fn=bias_posterior_tensor_fn,
            bias_prior_fn=bias_prior_fn,
            bias_divergence_fn=bias_divergence_fn,
            **kwargs)

        self.output_padding = output_padding
        if self.output_padding is not None:
            self.output_padding = conv_utils.normalize_tuple(
                self.output_padding, 2, 'output_padding')
            for stride, out_pad in zip(self.strides, self.output_padding):
                if out_pad >= stride:
                    raise ValueError('Stride ' + str(self.strides) +
                                     ' must be '
                                     'greater than output padding ' +
                                     str(self.output_padding))
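The output_padding check above encodes the shape rule for transposed convolutions. As a standalone note for this listing (not part of the source), the output length along one spatial axis, given an explicit output_padding, is (in - 1) * stride + kernel - 2 * pad + output_padding, where pad is 0 for 'valid' and kernel // 2 for 'same'; output_padding must stay strictly below the stride for the shape to be unambiguous:

# Standalone sketch of the transposed-convolution shape rule (illustrative).
def deconv_output_length(input_length, kernel, stride,
                         padding='valid', output_padding=0):
    if output_padding >= stride:
        raise ValueError(f'output_padding ({output_padding}) must be '
                         f'less than stride ({stride})')
    pad = 0 if padding == 'valid' else kernel // 2
    return (input_length - 1) * stride + kernel - 2 * pad + output_padding

# Upsampling a 7-wide axis with kernel 3, stride 2, 'same' padding:
assert deconv_output_length(7, 3, 2, padding='same', output_padding=1) == 14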
Example #4
 def __init__(
     self,
     rank,
     filters,
     kernel_size,
     is_mc,
     strides=1,
     padding="valid",
     data_format="channels_last",
     dilation_rate=1,
     activation=None,
     activity_regularizer=None,
     kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
     kernel_posterior_tensor_fn=lambda d: d.sample(),
     kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
     kernel_divergence_fn=(lambda q, p, ignore: kl_lib.kl_divergence(q, p)),
     bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
         is_singular=True
     ),
     bias_posterior_tensor_fn=lambda d: d.sample(),
     bias_prior_fn=None,
     bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
     **kwargs
 ):
     super(_ConvVariational, self).__init__(
         activity_regularizer=activity_regularizer, **kwargs
     )
     self.rank = rank
     self.is_mc = is_mc
     self.filters = filters
     self.kernel_size = tf_layers_util.normalize_tuple(
         kernel_size, rank, "kernel_size"
     )
     self.strides = tf_layers_util.normalize_tuple(strides, rank, "strides")
     self.padding = tf_layers_util.normalize_padding(padding)
     self.data_format = tf_layers_util.normalize_data_format(data_format)
     self.dilation_rate = tf_layers_util.normalize_tuple(
         dilation_rate, rank, "dilation_rate"
     )
     self.activation = tf.keras.activations.get(activation)
     self.input_spec = tf.keras.layers.InputSpec(ndim=self.rank + 2)
     self.kernel_posterior_fn = kernel_posterior_fn
     self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
     self.kernel_prior_fn = kernel_prior_fn
     self.kernel_divergence_fn = kernel_divergence_fn
     self.bias_posterior_fn = bias_posterior_fn
     self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
     self.bias_prior_fn = bias_prior_fn
     self.bias_divergence_fn = bias_divergence_fn
Example #5
def fullnet(numclass,
            activation=tf.nn.relu,
            priorstd=1,
            poststd=None,
            layer_sizes=[100, 50, 10],
            isBay=False):

    priorfn = gen_priordist(std=priorstd)
    if poststd is None:
        postfn = tfp_layers_util.default_mean_field_normal_fn()
    else:
        postfn = gen_postdist(std=poststd)

    model = tf.keras.Sequential()
    # Hidden layers; the last entry of layer_sizes is superseded by numclass.
    for size in layer_sizes[:-1]:
        if isBay:
            layer = tfp.layers.DenseFlipout(size,
                                            activation=activation,
                                            kernel_prior_fn=priorfn,
                                            kernel_posterior_fn=postfn)
        else:
            layer = tf.keras.layers.Dense(size, activation=activation)
        model.add(layer)

    if isBay:
        model.add(
            tfp.layers.DenseFlipout(numclass,
                                    kernel_prior_fn=priorfn,
                                    kernel_posterior_fn=postfn))
    else:
        model.add(tf.keras.layers.Dense(numclass))

    return model
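A hedged usage sketch for fullnet. Note that gen_priordist and gen_postdist are helpers from the same module that are not shown in this excerpt, so this only runs where they are defined; the shapes are illustrative.

# Hypothetical usage (requires gen_priordist / gen_postdist from the module):
model = fullnet(numclass=10, layer_sizes=[100, 50, 10], isBay=True)
model.build(input_shape=(None, 784))  # e.g. flattened 28x28 images
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    # The final DenseFlipout has no activation, so it emits logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
# Keras folds the Flipout layers' KL terms (model.losses) into the training loss.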
Example #6
    def __init__(
            self, units,
            activation=None,
            activity_regularizer=None,
            client_weight=1.,
            trainable=True,
            kernel_posterior_fn=None,
            kernel_posterior_tensor_fn=(lambda d: d.sample()),
            kernel_prior_fn=None,
            kernel_divergence_fn=(
                    lambda q, p, ignore: tfd.kl_divergence(q, p)),
            bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
                is_singular=True),
            bias_posterior_tensor_fn=(lambda d: d.sample()),
            bias_prior_fn=None,
            bias_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)),
            **kwargs):

        self.untransformed_scale_initializer = kwargs.pop(
            'untransformed_scale_initializer', None)
        self.loc_initializer = kwargs.pop('loc_initializer', None)

        self.delta_percentile = kwargs.pop('delta_percentile', None)

        if kernel_posterior_fn is None:
            kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn
        if kernel_prior_fn is None:
            kernel_prior_fn = self.natural_tensor_multivariate_normal_fn

        super(DenseSharedNatural, self).__init__(
            units,
            activation=activation,
            activity_regularizer=activity_regularizer,
            trainable=trainable,
            kernel_posterior_fn=kernel_posterior_fn,
            kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
            kernel_prior_fn=kernel_prior_fn,
            kernel_divergence_fn=kernel_divergence_fn,
            bias_posterior_fn=bias_posterior_fn,
            bias_posterior_tensor_fn=bias_posterior_tensor_fn,
            bias_prior_fn=bias_prior_fn,
            bias_divergence_fn=bias_divergence_fn,
            **kwargs)

        self.client_weight = client_weight
        self.delta_function = tf.subtract
        if self.delta_percentile and activation != 'softmax':
            self.delta_function = sparse_delta_function(self.delta_percentile)
            print(self, activation, 'using delta sparsification')
        self.apply_delta_function = tf.add
        self.client_variable_dict = {}
        self.client_center_variable_dict = {}
        self.server_variable_dict = {}
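The delta_function / apply_delta_function pair above is the federated update rule: clients send the difference between their weights and the server's, and the server adds it back. A minimal round-trip sketch (variable names are illustrative, not from the source):

# Minimal sketch of the client/server delta round trip (illustrative values).
server_w = tf.constant([0.5, -0.2, 1.0])
client_w = tf.constant([0.7, -0.1, 0.9])

delta = tf.subtract(client_w, server_w)  # delta_function: what the client sends
updated = tf.add(server_w, delta)        # apply_delta_function: server applies it
assert bool(tf.reduce_all(tf.equal(updated, client_w)))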
Example #7
 def __init__(
     self,
     rank,
     filters,
     kernel_size,
     is_mc,
     strides=1,
     padding="valid",
     data_format="channels_last",
     dilation_rate=1,
     activation=None,
     activity_regularizer=None,
     kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
     kernel_posterior_tensor_fn=lambda d: d.sample(),
     kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
     kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
     bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
         is_singular=True
     ),
     bias_posterior_tensor_fn=lambda d: d.sample(),
     bias_prior_fn=None,
     bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
     **kwargs
 ):
     super(_ConvReparameterization, self).__init__(
         rank=rank,
         filters=filters,
         kernel_size=kernel_size,
         strides=strides,
         padding=padding,
         is_mc=is_mc,
         data_format=data_format,
         dilation_rate=dilation_rate,
         activation=tf.keras.activations.get(activation),
         activity_regularizer=activity_regularizer,
         kernel_posterior_fn=kernel_posterior_fn,
         kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
         kernel_prior_fn=kernel_prior_fn,
         kernel_divergence_fn=kernel_divergence_fn,
         bias_posterior_fn=bias_posterior_fn,
         bias_posterior_tensor_fn=bias_posterior_tensor_fn,
         bias_prior_fn=bias_prior_fn,
         bias_divergence_fn=bias_divergence_fn,
         **kwargs
     )
Example #8
    def __init__(
            self,
            filters,
            kernel_size,
            strides=1,
            padding='valid',
            client_weight=1.,
            data_format='channels_last',
            dilation_rate=1,
            activation=None,
            activity_regularizer=None,
            kernel_posterior_fn=None,
            kernel_posterior_tensor_fn=(lambda d: d.sample()),
            kernel_prior_fn=None,
            kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
                is_singular=True),
            bias_posterior_tensor_fn=lambda d: d.sample(),
            bias_prior_fn=None,
            bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            **kwargs):

        self.untransformed_scale_initializer = kwargs.pop(
            'untransformed_scale_initializer', None)

        if kernel_posterior_fn is None:
            kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn
        if kernel_prior_fn is None:
            kernel_prior_fn = self.natural_tensor_multivariate_normal_fn

        super(Conv1DVirtualNatural, self).__init__(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=tf.keras.activations.get(activation),
            activity_regularizer=activity_regularizer,
            kernel_posterior_fn=kernel_posterior_fn,
            kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
            kernel_prior_fn=kernel_prior_fn,
            kernel_divergence_fn=kernel_divergence_fn,
            bias_posterior_fn=bias_posterior_fn,
            bias_posterior_tensor_fn=bias_posterior_tensor_fn,
            bias_prior_fn=bias_prior_fn,
            bias_divergence_fn=bias_divergence_fn,
            **kwargs)

        self.client_weight = client_weight
        self.delta_function = tf.subtract
        self.apply_delta_function = tf.add
        self.client_variable_dict = {}
        self.client_center_variable_dict = {}
        self.server_variable_dict = {}
Example #9
    def __init__(
            self,
            units,
            activation=None,
            activity_regularizer=None,
            trainable=True,
            kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
            kernel_posterior_tensor_fn=lambda d: d.sample(),
            kernel_prior_fn=lambda dtype, *args: tfd.Normal(  # pylint: disable=g-long-lambda
                loc=dtype.as_numpy_dtype(0.),
                scale=dtype.as_numpy_dtype(1.)),
            kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
                is_singular=True),
            bias_posterior_tensor_fn=lambda d: d.sample(),
            bias_prior_fn=None,
            bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            seed=None,
            **kwargs):
        # pylint: disable=g-doc-args
        """Construct layer.

    Args:
      @{args}
      seed: Python scalar `int` which initializes the random number
        generator. Default value: `None` (i.e., use global seed).
    """
        # pylint: enable=g-doc-args
        super(DenseFlipout, self).__init__(
            units=units,
            activation=activation,
            activity_regularizer=activity_regularizer,
            trainable=trainable,
            kernel_posterior_fn=kernel_posterior_fn,
            kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
            kernel_prior_fn=kernel_prior_fn,
            kernel_divergence_fn=kernel_divergence_fn,
            bias_posterior_fn=bias_posterior_fn,
            bias_posterior_tensor_fn=bias_posterior_tensor_fn,
            bias_prior_fn=bias_prior_fn,
            bias_divergence_fn=bias_divergence_fn,
            **kwargs)
        # Set additional attributes which do not exist in the parent class.
        self.seed = seed
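A hedged usage sketch: the public tfp.layers.DenseFlipout exposes the same signature, and a fixed seed makes the layer's random sign-flip perturbations reproducible (each call is still a fresh stochastic pass).

# Hypothetical usage via the public layer with the same signature.
layer = tfp.layers.DenseFlipout(units=32, activation='relu', seed=42)
y = layer(tf.random.normal([16, 8]))  # stochastic forward pass
kl_terms = layer.losses               # KL divergence terms added at build time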
Example #10
    def __init__(
            self,
            units,
            activation=None,
            activity_regularizer=None,
            trainable=True,
            kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
            kernel_posterior_tensor_fn=lambda d: d.sample(),
            kernel_prior_fn=lambda dtype, *args: tfd.Normal(  # pylint: disable=g-long-lambda
                loc=dtype.as_numpy_dtype(0.),
                scale=dtype.as_numpy_dtype(1.)),
            kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
                is_singular=True),
            bias_posterior_tensor_fn=lambda d: d.sample(),
            bias_prior_fn=None,
            bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            seed=None,
            name=None,
            **kwargs):
        # pylint: disable=g-doc-args
        """Construct layer.

    Args:
      @{args}
    """
        # pylint: enable=g-doc-args
        super(DenseFlipout, self).__init__(
            units=units,
            activation=activation,
            activity_regularizer=activity_regularizer,
            trainable=trainable,
            kernel_posterior_fn=kernel_posterior_fn,
            kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
            kernel_prior_fn=kernel_prior_fn,
            kernel_divergence_fn=kernel_divergence_fn,
            bias_posterior_fn=bias_posterior_fn,
            bias_posterior_tensor_fn=bias_posterior_tensor_fn,
            bias_prior_fn=bias_prior_fn,
            bias_divergence_fn=bias_divergence_fn,
            name=name,
            **kwargs)
        self.seed = seed
Example #11
    def __init__(
            self,
            units,
            activation=None,
            activity_regularizer=None,
            trainable=True,
            kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
            kernel_posterior_tensor_fn=lambda d: d.sample(),
            kernel_prior_fn=lambda dtype, *args: tfd.Normal(  # pylint: disable=g-long-lambda
                loc=dtype.as_numpy_dtype(0.),
                scale=dtype.as_numpy_dtype(1.)),
            kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
                is_singular=True),  # pylint: disable=line-too-long
            bias_posterior_tensor_fn=lambda d: d.sample(),
            bias_prior_fn=None,
            bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            name=None,
            **kwargs):
        # pylint: disable=g-doc-args
        """Construct layer.

    Args:
      @{args}
    """
        # pylint: enable=g-doc-args
        super(_DenseVariational, self).__init__(
            trainable=trainable,
            name=name,
            activity_regularizer=activity_regularizer,
            **kwargs)
        self.units = units
        self.activation = activation
        self.input_spec = tf.layers.InputSpec(min_ndim=2)
        self.kernel_posterior_fn = kernel_posterior_fn
        self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
        self.kernel_prior_fn = kernel_prior_fn
        self.kernel_divergence_fn = kernel_divergence_fn
        self.bias_posterior_fn = bias_posterior_fn
        self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
        self.bias_prior_fn = bias_prior_fn
        self.bias_divergence_fn = bias_divergence_fn
Example #12
    def __init__(
            self,
            units,
            activation=None,
            activity_regularizer=None,
            trainable=True,
            kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
            kernel_posterior_tensor_fn=lambda d: d.sample(),
            kernel_prior_fn=lambda dtype, shape, *dummy_args: tfd.Independent(  # pylint: disable=g-long-lambda
                tfd.Normal(loc=tf.zeros(shape, dtype),
                           scale=dtype.as_numpy_dtype(1.))),
            kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
                is_singular=True),
            bias_posterior_tensor_fn=lambda d: d.sample(),
            bias_prior_fn=None,
            bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
            **kwargs):
        # pylint: disable=g-doc-args
        """Construct layer.

    Args:
      @{args}
    """
        # pylint: enable=g-doc-args
        super(DenseLocalReparameterization, self).__init__(
            units=units,
            activation=activation,
            activity_regularizer=activity_regularizer,
            trainable=trainable,
            kernel_posterior_fn=kernel_posterior_fn,
            kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
            kernel_prior_fn=kernel_prior_fn,
            kernel_divergence_fn=kernel_divergence_fn,
            bias_posterior_fn=bias_posterior_fn,
            bias_posterior_tensor_fn=bias_posterior_tensor_fn,
            bias_prior_fn=bias_prior_fn,
            bias_divergence_fn=bias_divergence_fn,
            **kwargs)
Example #13
    def __init__(
            self,
            units,
            activation=None,
            activity_regularizer=None,
            trainable=True,
            kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
            kernel_posterior_tensor_fn=lambda d: d.sample(),
            kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
            kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(
                q, p),
            bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
                is_singular=True),
            bias_posterior_tensor_fn=lambda d: d.sample(),
            bias_prior_fn=None,
            bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
            **kwargs):
        # pylint: disable=g-doc-args
        """Construct layer.

    Args:
      ${args}
    """
        # pylint: enable=g-doc-args
        super(DenseLocalReparameterization, self).__init__(
            units=units,
            activation=activation,
            activity_regularizer=activity_regularizer,
            trainable=trainable,
            kernel_posterior_fn=kernel_posterior_fn,
            kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
            kernel_prior_fn=kernel_prior_fn,
            kernel_divergence_fn=kernel_divergence_fn,
            bias_posterior_fn=bias_posterior_fn,
            bias_posterior_tensor_fn=bias_posterior_tensor_fn,
            bias_prior_fn=bias_prior_fn,
            bias_divergence_fn=bias_divergence_fn,
            **kwargs)
Example #14
def convnet(inshape,
            numclass,
            activation=tf.nn.relu,
            regularizer=0.0,
            priorstd=1,
            poststd=None,
            isBay=False,
            repeatConv=1):
    priorfn = gen_priordist(std=priorstd)
    if poststd is None:
        postfn = tfp_layers_util.default_mean_field_normal_fn()
    else:
        postfn = gen_postdist(std=poststd)
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Reshape(inshape))

    for _ in range(repeatConv):
        model.add(
            tf.keras.layers.Conv2D(32,
                                   kernel_size=3,
                                   padding="SAME",
                                   activation=activation))
        model.add(
            tf.keras.layers.MaxPool2D(pool_size=[2, 2],
                                      strides=[2, 2],
                                      padding='SAME'))

    model.add(tf.keras.layers.Flatten())

    if isBay:
        model.add(
            tfp.layers.DenseFlipout(numclass,
                                    kernel_prior_fn=priorfn,
                                    kernel_posterior_fn=postfn))
    else:
        model.add(
            tf.keras.layers.Dense(
                numclass,
                kernel_regularizer=tf.keras.regularizers.l2(regularizer)))

    return model
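A hedged usage sketch for convnet; as with fullnet above, gen_priordist / gen_postdist come from the surrounding module and are not shown here.

# Hypothetical usage (requires gen_priordist / gen_postdist from the module).
# inshape restores the image shape from flat inputs, e.g. 784 -> (28, 28, 1).
model = convnet(inshape=(28, 28, 1), numclass=10,
                regularizer=1e-4, isBay=False, repeatConv=2)
model.build(input_shape=(None, 28 * 28))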
Example #15
def convnet(inshape,
            numclass,
            activation=tf.nn.relu,
            priorstd=1,
            poststd=None,
            isBay=False):
    priorfn = gen_priordist(std=priorstd)
    if poststd is None:
        postfn = tfp_layers_util.default_mean_field_normal_fn()
    else:
        postfn = gen_postdist(std=poststd)
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Reshape(inshape))
    # if isBay:
    #     layer = tfp.layers.Convolution2DFlipout(
    #         32, kernel_size=3, padding="SAME",
    #         activation=self.activation)
    # else:
    layer = tf.keras.layers.Conv2D(32,
                                   kernel_size=3,
                                   padding="SAME",
                                   activation=activation)
    model.add(layer)
    model.add(
        tf.keras.layers.MaxPool2D(pool_size=[2, 2],
                                  strides=[2, 2],
                                  padding='SAME'))
    model.add(tf.keras.layers.Flatten())
    if isBay:
        model.add(
            tfp.layers.DenseFlipout(numclass,
                                    kernel_prior_fn=priorfn,
                                    kernel_posterior_fn=postfn))
    else:
        model.add(tf.keras.layers.Dense(numclass))

    return model
Example #16
    def __init__(self,
                 deg=1,
                 output_units=1,
                 use_xbias=True,
                 init_w=None,
                 name=None,
                 activation=None,
                 trainable=True,
                 kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
                 kernel_posterior_tensor_fn=lambda d: d.sample(),
                 kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
                 kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p)):
        if 'tf' not in keras.__version__:
            raise EnvironmentError("The current implementation of this layer does not allow it to be run with standalone Keras; "
                                   "please modify the astroNN config in ~/config.ini: key -> tensorflow_keras = tensorflow")
        super().__init__(name=name,
                         units=output_units,
                         trainable=trainable,
                         kernel_posterior_fn=kernel_posterior_fn,
                         kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
                         kernel_prior_fn=kernel_prior_fn,
                         kernel_divergence_fn=kernel_divergence_fn)
        self.input_spec = InputSpec(min_ndim=2)
        self.deg = deg
        self.output_units = output_units
        self.use_bias = use_xbias
        self.activation = activations.get(activation)
        self.kernel_posterior_fn = kernel_posterior_fn
        self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
        self.kernel_prior_fn = kernel_prior_fn
        self.kernel_divergence_fn = kernel_divergence_fn
        self.init_w = init_w

        if self.init_w is not None and len(self.init_w) != self.deg + 1:
            raise ValueError(f"If you specify initial weights for a {self.deg}-degree polynomial, "
                             f"you must provide {self.deg + 1} weights")
Example #17
def get_posterior_fn():
  return tfp_layers_util.default_mean_field_normal_fn(
      loc_initializer=tf1.initializers.he_normal(), 
      untransformed_scale_initializer=tf1.initializers.random_normal(
          mean=-9.0, stddev=0.1)
      )
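Why mean=-9.0? default_mean_field_normal_fn passes the untransformed scale through a softplus (plus a tiny stability constant in some TFP versions), so the posterior stddev starts near softplus(-9.0) ≈ 1.2e-4: an almost deterministic initialization that training can widen via the KL term. A quick check:

# softplus(-9.0) = log(1 + exp(-9.0)) ~ 1.234e-4
import math
print(math.log1p(math.exp(-9.0)))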
Example #18
 def __init__(
     self,
     filters,
     kernel_size,
     strides=(1, 1, 1),
     padding="valid",
     data_format="channels_last",
     dilation_rate=(1, 1, 1),
     activation=None,
     activity_regularizer=None,
     is_mc=tf.constant(False, dtype=tf.bool),
     kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
     kernel_posterior_tensor_fn=lambda d: d.sample(),
     kernel_prior_fn=tfp_layers_util.default_multivariate_normal_fn,
     kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
     bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
         is_singular=True
     ),
     bias_posterior_tensor_fn=lambda d: d.sample(),
     bias_prior_fn=None,
     bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p),
     **kwargs
 ):
     # pylint: disable=g-doc-args
     """Construct layer.
     Args:
       filters: Integer, the dimensionality of the output space (i.e. the number
         of filters in the convolution).
       kernel_size: An integer or tuple/list of 3 integers, specifying the
         depth, height and width of the 3D convolution window.
         Can be a single integer to specify the same value for
         all spatial dimensions.
       strides: An integer or tuple/list of 3 integers,
         specifying the strides of the convolution along the depth,
         height and width.
         Can be a single integer to specify the same value for
         all spatial dimensions.
         Specifying any stride value != 1 is incompatible with specifying
         any `dilation_rate` value != 1.
       padding: One of `"valid"` or `"same"` (case-insensitive).
       data_format: A string, one of `channels_last` (default) or
         `channels_first`. The ordering of the dimensions in the inputs.
         `channels_last` corresponds to inputs with shape `(batch, depth,
         height, width, channels)` while `channels_first` corresponds to inputs
         with shape `(batch, channels, depth, height, width)`.
       dilation_rate: An integer or tuple/list of 3 integers, specifying
         the dilation rate to use for dilated convolution.
         Can be a single integer to specify the same value for
         all spatial dimensions.
         Currently, specifying any `dilation_rate` value != 1 is
         incompatible with specifying any stride value != 1.
       ${args}"""
     super(Conv3DReparameterization, self).__init__(
         rank=3,
         filters=filters,
         kernel_size=kernel_size,
         is_mc=is_mc,
         strides=strides,
         padding=padding,
         data_format=data_format,
         dilation_rate=dilation_rate,
         activation=tf.keras.activations.get(activation),
         activity_regularizer=activity_regularizer,
         kernel_posterior_fn=kernel_posterior_fn,
         kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
         kernel_prior_fn=kernel_prior_fn,
         kernel_divergence_fn=kernel_divergence_fn,
         bias_posterior_fn=bias_posterior_fn,
         bias_posterior_tensor_fn=bias_posterior_tensor_fn,
         bias_prior_fn=bias_prior_fn,
         bias_divergence_fn=bias_divergence_fn,
         **kwargs
     )
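A hedged usage sketch for the subclass above. The is_mc flag is specific to this fork (it is not part of stock TFP) and is assumed to toggle Monte Carlo sampling in the forward pass:

# Hypothetical usage; `is_mc` is this fork's own flag, assumed to enable
# stochastic (Monte Carlo) forward passes.
layer = Conv3DReparameterization(
    filters=16,
    kernel_size=3,
    is_mc=tf.constant(True),
    activation='relu',
    padding='same')
y = layer(tf.random.normal([2, 8, 8, 8, 1]))  # (batch, depth, height, width, channels)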
Example #19
def dense_flipout(
        inputs,
        units,
        activation=None,
        activity_regularizer=None,
        trainable=True,
        kernel_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(),
        kernel_posterior_tensor_fn=lambda d: d.sample(),
        kernel_prior_fn=lambda dtype, *args: tfd.Normal(  # pylint: disable=g-long-lambda
            loc=dtype.as_numpy_dtype(0.),
            scale=dtype.as_numpy_dtype(1.)),
        kernel_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
        bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
            is_singular=True),
        bias_posterior_tensor_fn=lambda d: d.sample(),
        bias_prior_fn=None,
        bias_divergence_fn=lambda q, p, ignore: tfd.kl_divergence(q, p),
        seed=None,
        name=None,
        reuse=None):
    # pylint: disable=g-doc-args
    """Densely-connected layer with Flipout estimator.

  This layer implements the Bayesian variational inference analogue to
  a dense layer by assuming the `kernel` and/or the `bias` are drawn
  from distributions. By default, the layer implements a stochastic
  forward pass via sampling from the kernel and bias posteriors,

  ```none
  kernel, bias ~ posterior
  outputs = activation(matmul(inputs, kernel) + bias)
  ```

  It uses the Flipout estimator [1], which performs a Monte Carlo
  approximation of the distribution integrating over the `kernel` and
  `bias`. Flipout uses roughly twice as many floating point operations
  as the reparameterization estimator but has the advantage of
  significantly lower variance.

  The arguments permit separate specification of the surrogate posterior
  (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias`
  distributions.

  Args:
    inputs: Tensor input.
    @{args}

  Returns:
    output: `Tensor` representing the affine-transformed input under a random
      draw from the surrogate posterior distribution.

  #### Examples

  We illustrate a Bayesian neural network with [variational inference](
  https://en.wikipedia.org/wiki/Variational_Bayesian_methods),
  assuming a dataset of `features` and `labels`.

  ```python
  import tensorflow_probability as tfp

  net = tfp.layers.dense_flipout(
      features, 512, activation=tf.nn.relu)
  logits = tfp.layers.dense_flipout(net, 10)
  neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits(
      labels=labels, logits=logits)
  kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
  loss = neg_log_likelihood + kl
  train_op = tf.train.AdamOptimizer().minimize(loss)
  ```

  It uses the Flipout gradient estimator to minimize the
  Kullback-Leibler divergence up to a constant, also known as the
  negative Evidence Lower Bound. It consists of the sum of two terms:
  the expected negative log-likelihood, which we approximate via
  Monte Carlo; and the KL divergence, which is added via regularizer
  terms which are arguments to the layer.

  [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on
        Mini-Batches."
        Anonymous. OpenReview, 2017.
        https://openreview.net/forum?id=rJnpifWAb
  """
    # pylint: enable=g-doc-args
    layer = DenseFlipout(units,
                         activation=activation,
                         activity_regularizer=activity_regularizer,
                         trainable=trainable,
                         kernel_posterior_fn=kernel_posterior_fn,
                         kernel_posterior_tensor_fn=kernel_posterior_tensor_fn,
                         kernel_prior_fn=kernel_prior_fn,
                         kernel_divergence_fn=kernel_divergence_fn,
                         bias_posterior_fn=bias_posterior_fn,
                         bias_posterior_tensor_fn=bias_posterior_tensor_fn,
                         bias_prior_fn=bias_prior_fn,
                         bias_divergence_fn=bias_divergence_fn,
                         seed=seed,
                         name=name,
                         dtype=inputs.dtype.base_dtype,
                         _scope=name,
                         _reuse=reuse)
    return layer.apply(inputs)
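The docstring example above is TF1-style (graph collections, tf.train). For reference, a hedged sketch of the same ELBO construction with the Keras layer API, where the KL terms accumulate in model.losses instead of a collection; features and labels are assumed to exist as in the docstring:

# Sketch only: ELBO loss with the Keras-style DenseFlipout layers.
model = tf.keras.Sequential([
    tfp.layers.DenseFlipout(512, activation='relu'),
    tfp.layers.DenseFlipout(10),
])
logits = model(features)
neg_log_likelihood = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits))
kl = sum(model.losses)  # one KL term per variational variable
loss = neg_log_likelihood + kl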
Example #20
def main(_):
    dim_output = FLAGS.dim_y
    dim_input = FLAGS.dim_im * FLAGS.dim_im * 1

    exp_name = '%s.num_noise-%g.noise-%g.beta-%g.meta_lr-%g.update_lr-%g.trial-%d' % (
        'maml_bbb', FLAGS.num_noise, FLAGS.noise_scale, FLAGS.beta,
        FLAGS.meta_lr, FLAGS.update_lr, FLAGS.trial)
    checkpoint_dir = os.path.join(FLAGS.logdir, exp_name)

    x_train, y_train = pickle.load(
        open(os.path.join(FLAGS.data_dir, FLAGS.data[0]), 'rb'))
    x_val, y_val = pickle.load(
        open(os.path.join(FLAGS.data_dir, FLAGS.data[1]), 'rb'))

    x_train, y_train = np.array(x_train), np.array(y_train)
    y_train = y_train[:, :, -1, None]
    x_val, y_val = np.array(x_val), np.array(y_val)
    y_val = y_val[:, :, -1, None]

    ds_train = tf.data.Dataset.from_generator(
        functools.partial(gen, x_train, y_train),
        (tf.float32, tf.float32, tf.float32, tf.float32),
        (tf.TensorShape(
            [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
         tf.TensorShape(
             [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output]),
         tf.TensorShape(
             [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
         tf.TensorShape(
             [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output])))

    ds_val = tf.data.Dataset.from_generator(
        functools.partial(gen, x_val, y_val),
        (tf.float32, tf.float32, tf.float32, tf.float32),
        (tf.TensorShape(
            [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
         tf.TensorShape(
             [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output]),
         tf.TensorShape(
             [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
         tf.TensorShape(
             [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output])))

    kernel_posterior_fn = tfp_layers_util.default_mean_field_normal_fn(
        untransformed_scale_initializer=tf.compat.v1.initializers.random_normal(
            mean=FLAGS.var, stddev=0.1))

    encoder_w = tf.keras.Sequential([
        tfp.layers.Convolution2DReparameterization(
            filters=32,
            kernel_size=3,
            strides=(2, 2),
            activation='relu',
            padding='SAME',
            kernel_posterior_fn=kernel_posterior_fn),
        tfp.layers.Convolution2DReparameterization(
            filters=48,
            kernel_size=3,
            strides=(2, 2),
            activation='relu',
            padding='SAME',
            kernel_posterior_fn=kernel_posterior_fn),
        MaxPooling2D(pool_size=(2, 2)),
        tfp.layers.Convolution2DReparameterization(
            filters=64,
            kernel_size=3,
            strides=(2, 2),
            activation='relu',
            padding='SAME',
            kernel_posterior_fn=kernel_posterior_fn),
        tf.keras.layers.Flatten(),
        tfp.layers.DenseReparameterization(
            FLAGS.dim_w, kernel_posterior_fn=kernel_posterior_fn),
    ])

    xa, labela, xb, labelb = ds_train.make_one_shot_iterator().get_next()
    xa = tf.reshape(xa, [-1, 128, 128, 1])
    xb = tf.reshape(xb, [-1, 128, 128, 1])
    with tf.variable_scope('encoder'):
        inputa = encoder_w(xa)
    inputa = tf.reshape(
        inputa, [-1, FLAGS.update_batch_size * FLAGS.num_classes, FLAGS.dim_w])
    inputb = encoder_w(xb)
    inputb = tf.reshape(
        inputb, [-1, FLAGS.update_batch_size * FLAGS.num_classes, FLAGS.dim_w])

    input_tensors = {'inputa': inputa,
                     'inputb': inputb,
                     'labela': labela,
                     'labelb': labelb}
    # n_task * n_im_per_task * dim_w
    xa_val, labela_val, xb_val, labelb_val = (
        ds_val.make_one_shot_iterator().get_next())
    xa_val = tf.reshape(xa_val, [-1, 128, 128, 1])
    xb_val = tf.reshape(xb_val, [-1, 128, 128, 1])

    inputa_val = encoder_w(xa_val)
    inputa_val = tf.reshape(
        inputa_val,
        [-1, FLAGS.update_batch_size * FLAGS.num_classes, FLAGS.dim_w])

    inputb_val = encoder_w(xb_val)
    inputb_val = tf.reshape(
        inputb_val,
        [-1, FLAGS.update_batch_size * FLAGS.num_classes, FLAGS.dim_w])

    metaval_input_tensors = {'inputa': inputa_val,
                             'inputb': inputb_val,
                             'labela': labela_val,
                             'labelb': labelb_val}

    # num_updates = max(self.test_num_updates, FLAGS.num_updates)
    model = MAML(encoder_w, FLAGS.dim_w, dim_output)
    model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    model.construct_model(input_tensors=metaval_input_tensors,
                          prefix='metaval_',
                          test_num_updates=FLAGS.test_num_updates)

    # model.construct_model(input_tensors=input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()
    sess = tf.InteractiveSession()

    tf.global_variables_initializer().run()

    if FLAGS.train:
        train(model, sess, checkpoint_dir)
Example #21
    def __init__(self,
                 units,
                 activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 use_bias=True,
                 kernel_initializer=tf.keras.initializers.VarianceScaling(scale=30.0,
                                                                          mode='fan_avg',
                                                                          distribution='uniform',),
                 recurrent_initializer=tf.keras.initializers.Orthogonal(gain=7.0),
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 implementation=1,
                 kernel_posterior_fn=None,
                 kernel_posterior_tensor_fn=(lambda d: d.sample()),
                 recurrent_kernel_posterior_fn=None,
                 recurrent_kernel_posterior_tensor_fn=(lambda d: d.sample()),
                 kernel_prior_fn=None,
                 recurrent_kernel_prior_fn=None,
                 kernel_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)),
                 recurrent_kernel_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)),
                 bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
                     is_singular=True),
                 bias_posterior_tensor_fn=(lambda d: d.sample()),
                 bias_prior_fn=None,
                 bias_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)),
                 client_weight=1.,
                 **kwargs):

        self.untransformed_scale_initializer = kwargs.pop('untransformed_scale_initializer', None)

        if kernel_posterior_fn is None:
            kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn
        if kernel_prior_fn is None:
            kernel_prior_fn = self.natural_tensor_multivariate_normal_fn
        if recurrent_kernel_posterior_fn is None:
            recurrent_kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn
        if recurrent_kernel_prior_fn is None:
            recurrent_kernel_prior_fn = self.natural_tensor_multivariate_normal_fn

        super(LSTMCellVariationalNatural, self).__init__(
            units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            unit_forget_bias=unit_forget_bias,
            kernel_regularizer=None,
            recurrent_regularizer=None,
            bias_regularizer=None,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            implementation=implementation,
            **kwargs)

        self.kernel_posterior_fn = kernel_posterior_fn
        self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
        self.recurrent_kernel_posterior_fn = recurrent_kernel_posterior_fn
        self.recurrent_kernel_posterior_tensor_fn = recurrent_kernel_posterior_tensor_fn
        self.kernel_prior_fn = kernel_prior_fn
        self.recurrent_kernel_prior_fn = recurrent_kernel_prior_fn
        self.kernel_divergence_fn = kernel_divergence_fn
        self.recurrent_kernel_divergence_fn = recurrent_kernel_divergence_fn
        self.bias_posterior_fn = bias_posterior_fn
        self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
        self.bias_prior_fn = bias_prior_fn
        self.bias_divergence_fn = bias_divergence_fn
        self.client_weight = client_weight
        self.delta_function = tf.subtract
        self.apply_delta_function = tf.add
        self.client_variable_dict = {}
        self.client_center_variable_dict = {}
        self.server_variable_dict = {}
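A hedged usage sketch: variational LSTM cells like the one above are wrapped in a Keras RNN layer. This assumes the natural-parameterization helper methods the constructor references (renormalize_natural_mean_field_normal_fn and natural_tensor_multivariate_normal_fn) are defined on the class; they are not shown in this excerpt.

# Hypothetical usage of the cell defined above.
cell = LSTMCellVariationalNatural(units=64, client_weight=1.0)
rnn = tf.keras.layers.RNN(cell)
outputs = rnn(tf.random.normal([4, 10, 32]))  # (batch, time, features)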
Example #22
def main(_):
    kernel_posterior_fn = tfp_layers_util.default_mean_field_normal_fn(
        untransformed_scale_initializer=tf.compat.v1.initializers.random_normal(
            mean=FLAGS.var, stddev=0.1))
    encoder_w0 = tf.keras.Sequential([
        tfp.layers.Convolution2DReparameterization(
            filters=32,
            kernel_size=3,
            strides=(2, 2),
            activation='relu',
            padding='SAME',
            kernel_posterior_fn=kernel_posterior_fn),
        tfp.layers.Convolution2DReparameterization(
            filters=48,
            kernel_size=3,
            strides=(2, 2),
            activation='relu',
            padding='SAME',
            kernel_posterior_fn=kernel_posterior_fn),
        MaxPooling2D(pool_size=(2, 2)),
        tfp.layers.Convolution2DReparameterization(
            filters=64,
            kernel_size=3,
            strides=(2, 2),
            activation='relu',
            padding='SAME',
            kernel_posterior_fn=kernel_posterior_fn),
        tf.keras.layers.Flatten(),
        tfp.layers.DenseReparameterization(
            FLAGS.dim_w, kernel_posterior_fn=kernel_posterior_fn),
    ])

    decoder0 = tf.keras.Sequential([
        tf.keras.layers.Dense(100, activation=tf.nn.relu),
        tf.keras.layers.Dense(100, activation=tf.nn.relu),
        tf.keras.layers.Dense(FLAGS.dim_y),
    ])

    dim_output = FLAGS.dim_y
    dim_input = FLAGS.dim_im * FLAGS.dim_im * 1

    exp_name = '%s.beta-%g.update_lr-%g.trial-%d' % (
        'np_bbb', FLAGS.beta, FLAGS.update_lr, FLAGS.trial)
    checkpoint_dir = os.path.join(FLAGS.logdir, exp_name)

    x_train, y_train = pickle.load(
        tf.io.gfile.GFile(os.path.join(get_data_dir(), FLAGS.data[0]), 'rb'))
    x_val, y_val = pickle.load(
        tf.io.gfile.GFile(os.path.join(get_data_dir(), FLAGS.data[1]), 'rb'))

    x_train, y_train = np.array(x_train), np.array(y_train)
    y_train = y_train[:, :, -1, None]
    x_val, y_val = np.array(x_val), np.array(y_val)
    y_val = y_val[:, :, -1, None]

    ds_train = tf.data.Dataset.from_generator(
        functools.partial(gen, x_train, y_train),
        (tf.float32, tf.float32, tf.float32, tf.float32),
        (tf.TensorShape(
            [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
         tf.TensorShape(
             [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output]),
         tf.TensorShape(
             [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
         tf.TensorShape(
             [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output])))

    ds_val = tf.data.Dataset.from_generator(
        functools.partial(gen, x_val, y_val),
        (tf.float32, tf.float32, tf.float32, tf.float32),
        (tf.TensorShape(
            [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
         tf.TensorShape(
             [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output]),
         tf.TensorShape(
             [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_input]),
         tf.TensorShape(
             [None, FLAGS.update_batch_size * FLAGS.num_classes, dim_output])))

    inputa, labela, inputb, labelb = (
        ds_train.make_one_shot_iterator().get_next())

    input_tensors = {'inputa': inputa,
                     'inputb': inputb,
                     'labela': labela,
                     'labelb': labelb}

    inputa_val, labela_val, inputb_val, labelb_val = (
        ds_val.make_one_shot_iterator().get_next())

    metaval_input_tensors = {'inputa': inputa_val,
                             'inputb': inputb_val,
                             'labela': labela_val,
                             'labelb': labelb_val}

    loss, train_op, facto = construct_model(input_tensors,
                                            encoder_w0,
                                            decoder0,
                                            prefix='metatrain_')
    loss_val = construct_model(metaval_input_tensors,
                               encoder_w0,
                               decoder0,
                               prefix='metaval_')

    ###########

    summ_op = tf.summary.merge_all()
    sess = tf.InteractiveSession()
    summary_writer = tf.summary.FileWriter(checkpoint_dir, sess.graph)
    tf.global_variables_initializer().run()

    PRINT_INTERVAL = 50  # pylint: disable=invalid-name
    SUMMARY_INTERVAL = 5  # pylint: disable=invalid-name
    prelosses, prelosses_val = [], []
    old_time = time.time()
    for itr in range(FLAGS.num_updates):

        feed_dict = {facto: FLAGS.facto}

        if itr % SUMMARY_INTERVAL == 0:
            summary, cost, cost_val = sess.run([summ_op, loss, loss_val],
                                               feed_dict)
            summary_writer.add_summary(summary, itr)
            prelosses.append(cost)  # 0 step loss on training set
            prelosses_val.append(
                cost_val)  # 0 step loss on meta_val training set

        sess.run(train_op, feed_dict)

        if (itr != 0) and itr % PRINT_INTERVAL == 0:
            print('Iteration ' + str(itr) + ': ' + str(np.mean(prelosses)),
                  'time =',
                  time.time() - old_time)
            prelosses = []
            old_time = time.time()
            print('Validation results: ' + str(np.mean(prelosses_val)))
            prelosses_val = []
Example #23
    def __init__(self,
                 units,
                 activation='tanh',
                 recurrent_activation='hard_sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 unit_forget_bias=True,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 implementation=1,
                 kernel_posterior_fn=None,
                 kernel_posterior_tensor_fn=(lambda d: d.sample()),
                 recurrent_kernel_posterior_fn=None,
                 recurrent_kernel_posterior_tensor_fn=(lambda d: d.sample()),
                 kernel_prior_fn=None,
                 recurrent_kernel_prior_fn=None,
                 kernel_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)),
                 recurrent_kernel_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)),
                 bias_posterior_fn=tfp_layers_util.default_mean_field_normal_fn(
                     is_singular=True),
                 bias_posterior_tensor_fn=(lambda d: d.sample()),
                 bias_prior_fn=None,
                 bias_divergence_fn=(lambda q, p, ignore: tfd.kl_divergence(q, p)),
                 client_weight=1.,
                 **kwargs):

        self.untransformed_scale_initializer = kwargs.pop(
            'untransformed_scale_initializer', None)

        if kernel_posterior_fn is None:
            kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn
        if kernel_prior_fn is None:
            kernel_prior_fn = self.natural_tensor_multivariate_normal_fn
        if recurrent_kernel_posterior_fn is None:
            recurrent_kernel_posterior_fn = self.renormalize_natural_mean_field_normal_fn
        if recurrent_kernel_prior_fn is None:
            recurrent_kernel_prior_fn = self.natural_tensor_multivariate_normal_fn

        super(LSTMCellReparametrizationNatural, self).__init__(
            units,
            activation=activation,
            recurrent_activation=recurrent_activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            recurrent_initializer=recurrent_initializer,
            bias_initializer=bias_initializer,
            unit_forget_bias=unit_forget_bias,
            kernel_regularizer=None,
            recurrent_regularizer=None,
            bias_regularizer=None,
            kernel_constraint=kernel_constraint,
            recurrent_constraint=recurrent_constraint,
            bias_constraint=bias_constraint,
            dropout=dropout,
            recurrent_dropout=recurrent_dropout,
            implementation=implementation,
            **kwargs)

        self.kernel_posterior_fn = kernel_posterior_fn
        self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn
        self.recurrent_kernel_posterior_fn = recurrent_kernel_posterior_fn
        self.recurrent_kernel_posterior_tensor_fn = recurrent_kernel_posterior_tensor_fn
        self.kernel_prior_fn = kernel_prior_fn
        self.recurrent_kernel_prior_fn = recurrent_kernel_prior_fn
        self.kernel_divergence_fn = kernel_divergence_fn
        self.recurrent_kernel_divergence_fn = recurrent_kernel_divergence_fn
        self.bias_posterior_fn = bias_posterior_fn
        self.bias_posterior_tensor_fn = bias_posterior_tensor_fn
        self.bias_prior_fn = bias_prior_fn
        self.bias_divergence_fn = bias_divergence_fn
        self.client_weight = client_weight