def __init__(self,
             input_dim,
             output_dim,
             x_encoder_net_sizes=None,
             x_y_encoder_net_sizes=None,
             heteroskedastic_net_sizes=None,
             global_latent_net_sizes=None,
             local_latent_net_sizes=None,
             decoder_net_sizes=None,
             att_type='multihead',
             att_heads=8,
             model_type='fully_connected',
             activation=tf.nn.relu,
             output_activation=None,
             model_path=None,
             data_uncertainty=True,
             beta=1.,
             temperature=1.):
        """Initializes the generalized neural process regressor.

    D below denotes:
    - Context dataset C during decoding phase
    - Target dataset T during encoding phase

    Args:
      input_dim: (int) Dimensionality of covariates x.
      output_dim: (int) Dimensionality of labels y.
      x_encoder_net_sizes: (list of ints) Hidden layer sizes for network
        featurizing x.
      x_y_encoder_net_sizes: (list of ints) Hidden layer sizes for network
        featurizing D.
      heteroskedastic_net_sizes: (list of ints) Hidden layer sizes for network
      that maps x to heteroskedastic variance.
      global_latent_net_sizes: (list of ints) Hidden layer sizes for network
        that maps D to mean and variance of predictive p(z | D).
      local_latent_net_sizes: (list of ints) Hidden layer sizes for network
        that maps xi, z, D to mean and variance of predictive p(zi | z, xi, D).
      decoder_net_sizes: (list of ints) Hidden layer sizes for network that maps
        xi, z, zi, D to mean and variance of predictive p(yi | z, zi, xi, D).
      att_type: (string) Attention type for freeform attention.
      att_heads: (int) Number of heads in case att_type='multihead'.
      model_type: (string) One of 'fully_connected', 'cnp', 'acnp', 'acns',
        'np', 'anp'.
      activation: (callable) Non-linearity used for all neural networks.
      output_activation: (callable) Non-linearity for predictive mean.
      model_path: (string) File path for best early-stopped model.
      data_uncertainty: (boolean) True if data uncertainty is explicit.
      beta: (float) Scaling factor for global kl loss.
      temperature: (float) Inverse scaling factor for temperature.

    Raises:
      ValueError: If model_type is unrecognized.
    """
        if model_type not in ['np', 'anp', 'acns', 'fully_connected', 'cnp',
                              'acnp']:
            raise ValueError('Unrecognized model type: %s' % model_type)

        super(Regressor, self).__init__()
        self._input_dim = input_dim
        self._output_dim = output_dim
        self.model_type = model_type
        self._output_activation = output_activation
        self._data_uncertainty = data_uncertainty
        self.beta = tf.constant(beta)
        self.temperature = temperature

        self._global_latent_layer = None
        self._local_latent_layer = None
        self._decoder_layer = None
        self._dataset_encoding_layer = None
        self._x_encoder = None
        self._heteroskedastic_net = None
        self._homoskedastic_net = None

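        # Model types that carry a global latent z, and the subset that also
        # carries per-point local latents zi (cf. the decoder inputs below).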
        contains_global = ['np', 'anp', 'acns', 'fully_connected']
        contains_local = ['acns', 'fully_connected']

        x_dim = input_dim
        if x_encoder_net_sizes is not None:
            self._x_encoder = utils.mlp_block(input_dim, x_encoder_net_sizes,
                                              activation)
            x_dim = x_encoder_net_sizes[-1]

        x_y_net = None
        self_dataset_attention = None
        if x_y_encoder_net_sizes is not None:
            x_y_net = utils.mlp_block(x_dim + output_dim,
                                      x_y_encoder_net_sizes, activation)
            dataset_encoding_dim = x_y_encoder_net_sizes[-1]
        else:
            # Use self-attention.
            dataset_encoding_dim = x_dim + output_dim
            self_dataset_attention = attention.AttentionLayer(
                att_type=att_type, num_heads=att_heads)
            self_dataset_attention.build([x_dim, x_dim])

        self._dataset_encoding_layer = layers.DatasetEncodingLayer(
            x_y_net, self_dataset_attention)
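        # Cross-dataset attention: (featurized) target inputs attend over the
        # context dataset encoding. The build() dims [x_dim,
        # dataset_encoding_dim] are assumed to be the key and value widths.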
        self._cross_dataset_attention = attention.AttentionLayer(
            att_type=att_type, num_heads=att_heads, scale=self.temperature)
        self._cross_dataset_attention.build([x_dim, dataset_encoding_dim])

        if model_type in contains_global:
            global_latent_net = utils.mlp_block(dataset_encoding_dim,
                                                global_latent_net_sizes,
                                                activation)
            self._global_latent_layer = layers.GlobalLatentLayer(
                global_latent_net)
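            # The final layer parameterizes both mean and variance of
            # p(z | D), so the latent dimension is half its width.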
            global_latent_dim = global_latent_net_sizes[-1] // 2

        if model_type in contains_local:
            local_input_dim = global_latent_dim + dataset_encoding_dim
            local_latent_net = utils.mlp_block(local_input_dim,
                                               local_latent_net_sizes,
                                               activation)
            self._local_latent_layer = layers.LocalLatentLayer(
                local_latent_net)
            local_latent_dim = local_latent_net_sizes[-1] // 2

            separate_prior_net = (model_type != 'fully_connected')
            if separate_prior_net:
                local_latent_net = utils.mlp_block(global_latent_dim,
                                                   local_latent_net_sizes,
                                                   activation)
                self._prior_local_latent_layer = layers.LocalLatentLayer(
                    local_latent_net)
            else:
                self._prior_local_latent_layer = self._local_latent_layer

        if decoder_net_sizes is not None:
            decoder_input_dim = x_dim
            if model_type == 'cnp' or model_type == 'acnp':  # depend on C
                decoder_input_dim += dataset_encoding_dim
            elif model_type == 'np':  # depend on z
                decoder_input_dim += global_latent_dim
            elif model_type == 'anp':  # depend on z, C
                decoder_input_dim += dataset_encoding_dim + global_latent_dim
            elif model_type == 'acns':
                decoder_input_dim += dataset_encoding_dim + local_latent_dim
            elif model_type == 'fully_connected':
                decoder_input_dim += (dataset_encoding_dim +
                                      global_latent_dim + local_latent_dim)
            decoder_net = utils.mlp_block(decoder_input_dim, decoder_net_sizes,
                                          activation)
            self._decoder_layer = layers.DecoderLayer(decoder_net, model_type,
                                                      output_activation)

        if data_uncertainty:
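            # With a heteroskedastic net, noise variance is input-dependent;
            # otherwise a single learned homoskedastic noise term is used.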
            if heteroskedastic_net_sizes is not None:
                self._heteroskedastic_net = utils.mlp_block(
                    x_dim, heteroskedastic_net_sizes, activation)
            else:
                self._homoskedastic_net = layers.DataNoise()
                self._homoskedastic_net.build(None)

        if model_path:
            self.load_weights(model_path)
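
A minimal usage sketch (hypothetical layer sizes; assumes the Regressor class
above together with its utils/layers/attention dependencies is importable):

# Hypothetical instantiation: a 'fully_connected' generalized neural process
# for 2-D inputs and scalar outputs. Latent nets end in 2 * latent_dim units
# (mean and variance; see the `// 2` splits above), and the decoder's final
# layer is sized 2 * output_dim on the same assumption.
model = Regressor(
    input_dim=2,
    output_dim=1,
    x_encoder_net_sizes=[64, 64],
    x_y_encoder_net_sizes=[128, 128],
    global_latent_net_sizes=[128, 64],  # 32-dim global latent z
    local_latent_net_sizes=[128, 64],   # 32-dim local latents zi
    decoder_net_sizes=[128, 2],         # mean and variance of y
    model_type='fully_connected')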
Example #2
  def __init__(self,
               input_dim,
               output_dim=1,
               x_encoder_sizes=None,
               x_y_encoder_sizes=None,
               heteroskedastic_net_sizes=None,
               global_latent_net_sizes=None,
               local_latent_net_sizes=None,
               att_type='multihead',
               att_heads=8,
               uncertainty_type='attentive_freeform',
               mean_att_type=attention.laplace_attention,
               scale_att_type_1=attention.squared_exponential_attention,
               scale_att_type_2=attention.squared_exponential_attention,
               activation=tf.nn.relu,
               output_activation=None,
               model_path=None,
               data_uncertainty=True,
               local_variational=True):
    """Initializes the structured neural process regressor.

    D below denotes:
    - Context dataset C during decoding phase
    - Target dataset T during encoding phase

    Args:
      input_dim: (int) Dimensionality of covariates x.
      output_dim: (int) Dimensionality of labels y.
      x_encoder_sizes: (list of ints) Hidden layer sizes for featurizing x.
      x_y_encoder_sizes: (list of ints) Hidden layer sizes for featurizing C/D.
      heteroskedastic_net_sizes: (list of ints) Hidden layer sizes for
        network that maps x to heteroskedastic variance.
      global_latent_net_sizes: (list of ints) Hidden layer sizes for network
        that maps D to mean and variance of predictive p(z | D).
      local_latent_net_sizes: (list of ints) Hidden layer sizes for network
        that maps xi, z, D to mean and variance of predictive p(zi | z, xi, D).
      att_type: (string) Attention type for freeform attention.
      att_heads: (int) Number of heads in case att_type='multihead'.
      uncertainty_type: (string) One of 'attentive_gp', 'attentive_freeform'.
        Default is 'attentive_freeform', which does not impose structure on
        the posterior mean and std.
      mean_att_type: (callable) Attention for mean of predictive
        p(zi | z, x, D).
      scale_att_type_1: (callable) Attention for std of predictive
        p(zi | z, x, D).
      scale_att_type_2: (callable) Attention for std of predictive
        p(zi | z, x, D).
      activation: (callable) Non-linearity used for all neural networks.
      output_activation: (callable) Non-linearity for predictive mean.
      model_path: (string) File path for best early-stopped model.
      data_uncertainty: (boolean) True if data uncertainty is modeled
        explicitly.
      local_variational: (boolean) True if variational inference is performed
        on the local latents.
    """
    super(Regressor, self).__init__()
    self._input_dim = input_dim
    self._output_dim = output_dim
    self._uncertainty_type = uncertainty_type
    self._output_activation = output_activation
    self._data_uncertainty = data_uncertainty
    self.local_variational = local_variational
    self._global_latent_layer = None
    self._local_latent_layer = None
    self._dataset_encoding_layer = None
    self._x_encoder = None
    self._heteroskedastic_net = None
    self._homoskedastic_net = None

    x_dim = input_dim
    if x_encoder_sizes is not None:
      self._x_encoder = utils.mlp_block(
          input_dim,
          x_encoder_sizes,
          activation)
      x_dim = x_encoder_sizes[-1]

    x_y_net = None
    self_dataset_attention = None
    if x_y_encoder_sizes is not None:
      x_y_net = utils.mlp_block(
          x_dim + output_dim,
          x_y_encoder_sizes,
          activation)
      dataset_encoding_dim = x_y_encoder_sizes[-1]
    else:
      # Use self-attention.
      dataset_encoding_dim = x_dim + output_dim
      self_dataset_attention = attention.AttentionLayer(
          att_type=att_type, num_heads=att_heads)
      self_dataset_attention.build([x_dim, x_dim])

    self._dataset_encoding_layer = layers.DatasetEncodingLayer(
        x_y_net,
        self_dataset_attention)
    self._cross_dataset_attention = attention.AttentionLayer(
        att_type=att_type, num_heads=att_heads)
    self._cross_dataset_attention.build([x_dim, dataset_encoding_dim])

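    # The local latent net consumes the featurized x, plus the global latent
    # sample and, for freeform uncertainty, the dataset encoding; accumulate
    # its input width accordingly.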
    local_latent_dim = x_dim
    if global_latent_net_sizes is not None:
      global_latent_net = utils.mlp_block(
          dataset_encoding_dim,
          global_latent_net_sizes,
          activation)
      self._global_latent_layer = layers.GlobalLatentLayer(global_latent_net)
      # Global latent dim is half the final layer (mean/variance split).
      local_latent_dim += global_latent_net_sizes[-1] // 2

    if local_latent_net_sizes is not None:
      # Freeform uncertainty directly attends to dataset encoding.
      if uncertainty_type == 'attentive_freeform':
        local_latent_dim += dataset_encoding_dim

      local_latent_net = utils.mlp_block(
          local_latent_dim,
          local_latent_net_sizes,
          activation)
      self._local_latent_layer = layers.SNPLocalLatentLayer(
          local_latent_net,
          uncertainty_type,
          mean_att_type,
          scale_att_type_1,
          scale_att_type_2,
          output_activation)

    if data_uncertainty:
      if heteroskedastic_net_sizes is not None:
        self._heteroskedastic_net = utils.mlp_block(
            x_dim,
            heteroskedastic_net_sizes,
            activation)
      else:
        self._homoskedastic_net = layers.DataNoise()
        self._homoskedastic_net.build(None)

    if model_path:
      self.load_weights(model_path)
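
A comparable sketch for the structured regressor defined in this example
(hypothetical sizes; the attention defaults from the signature are kept):

# Hypothetical instantiation: an 'attentive_gp' structured neural process,
# which (per the docstring) imposes structure on the posterior mean and std
# through the mean/scale attention kernels.
snp = Regressor(
    input_dim=2,
    output_dim=1,
    x_encoder_sizes=[64, 64],
    x_y_encoder_sizes=[128, 128],
    global_latent_net_sizes=[128, 64],
    local_latent_net_sizes=[128, 64],
    uncertainty_type='attentive_gp',
    local_variational=True)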