def __init__(self,
             input_dim,
             output_dim=1,
             x_encoder_sizes=None,
             x_y_encoder_sizes=None,
             heteroskedastic_net_sizes=None,
             global_latent_net_sizes=None,
             local_latent_net_sizes=None,
             att_type='multihead',
             att_heads=8,
             uncertainty_type='attentive_freeform',
             mean_att_type=attention.laplace_attention,
             scale_att_type_1=attention.squared_exponential_attention,
             scale_att_type_2=attention.squared_exponential_attention,
             activation=tf.nn.relu,
             output_activation=None,
             model_path=None,
             data_uncertainty=True,
             local_variational=True):
  """Initializes the structured neural process regressor.

  D below denotes:
  - the context dataset C during the decoding phase,
  - the target dataset T during the encoding phase.

  Args:
    input_dim: (int) Dimensionality of covariates x.
    output_dim: (int) Dimensionality of labels y.
    x_encoder_sizes: (list of ints) Hidden layer sizes for featurizing x.
    x_y_encoder_sizes: (list of ints) Hidden layer sizes for featurizing C/D.
    heteroskedastic_net_sizes: (list of ints) Hidden layer sizes for the
      network that maps x to the heteroskedastic variance.
    global_latent_net_sizes: (list of ints) Hidden layer sizes for the
      network that maps D to the mean and variance of the predictive
      p(z | D).
    local_latent_net_sizes: (list of ints) Hidden layer sizes for the
      network that maps x_i, z, D to the mean and variance of the
      predictive p(z_i | z, x_i, D).
    att_type: (string) Attention type for freeform attention.
    att_heads: (int) Number of heads when att_type='multihead'.
    uncertainty_type: (string) One of 'attentive_gp', 'attentive_freeform'.
      The default, 'attentive_freeform', imposes no structure on the
      posterior mean and std.
    mean_att_type: (callable) Attention for the mean of the predictive
      p(z_i | z, x_i, D).
    scale_att_type_1: (callable) First attention for the std of the
      predictive p(z_i | z, x_i, D).
    scale_att_type_2: (callable) Second attention for the std of the
      predictive p(z_i | z, x_i, D).
    activation: (callable) Non-linearity used for all neural networks.
    output_activation: (callable) Non-linearity for the predictive mean.
    model_path: (string) File path for the best early-stopped model.
    data_uncertainty: (boolean) True if data uncertainty is explicit.
    local_variational: (boolean) True if VI is performed on local latents.
  """
  super(Regressor, self).__init__()
  self._input_dim = input_dim
  self._output_dim = output_dim
  self._uncertainty_type = uncertainty_type
  self._output_activation = output_activation
  self._data_uncertainty = data_uncertainty
  self.local_variational = local_variational

  self._global_latent_layer = None
  self._local_latent_layer = None
  self._dataset_encoding_layer = None
  self._x_encoder = None
  self._heteroskedastic_net = None
  self._homoskedastic_net = None

  x_dim = input_dim
  if x_encoder_sizes is not None:
    self._x_encoder = utils.mlp_block(input_dim, x_encoder_sizes, activation)
    x_dim = x_encoder_sizes[-1]

  x_y_net = None
  self_dataset_attention = None
  if x_y_encoder_sizes is not None:
    x_y_net = utils.mlp_block(x_dim + output_dim, x_y_encoder_sizes,
                              activation)
    dataset_encoding_dim = x_y_encoder_sizes[-1]
  else:  # Use self-attention.
    dataset_encoding_dim = x_dim + output_dim
    self_dataset_attention = attention.AttentionLayer(
        att_type=att_type, num_heads=att_heads)
    self_dataset_attention.build([x_dim, x_dim])
  self._dataset_encoding_layer = layers.DatasetEncodingLayer(
      x_y_net, self_dataset_attention)
  self._cross_dataset_attention = attention.AttentionLayer(
      att_type=att_type, num_heads=att_heads)
  self._cross_dataset_attention.build([x_dim, dataset_encoding_dim])

  local_latent_dim = x_dim
  if global_latent_net_sizes is not None:
    global_latent_net = utils.mlp_block(dataset_encoding_dim,
                                        global_latent_net_sizes, activation)
    self._global_latent_layer = layers.GlobalLatentLayer(global_latent_net)
    local_latent_dim += global_latent_net_sizes[-1] // 2

  if local_latent_net_sizes is not None:
    # Freeform uncertainty directly attends to the dataset encoding.
    if uncertainty_type == 'attentive_freeform':
      local_latent_dim += dataset_encoding_dim
    local_latent_net = utils.mlp_block(local_latent_dim,
                                       local_latent_net_sizes, activation)
    self._local_latent_layer = layers.SNPLocalLatentLayer(
        local_latent_net, uncertainty_type, mean_att_type,
        scale_att_type_1, scale_att_type_2, output_activation)

  if data_uncertainty:
    if heteroskedastic_net_sizes is not None:
      self._heteroskedastic_net = utils.mlp_block(
          x_dim, heteroskedastic_net_sizes, activation)
    else:
      self._homoskedastic_net = layers.DataNoise()
      self._homoskedastic_net.build(None)

  if model_path:
    self.load_weights(model_path)
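# Example usage (an illustrative sketch, not part of the original module):
# constructing a structured neural process regressor for scalar regression.
# The layer sizes below are arbitrary; the last entry of each latent net
# parameterizes both mean and variance (the constructor halves it to get
# the latent dimension), so it should be even.
#
#   snp = Regressor(
#       input_dim=1,
#       output_dim=1,
#       x_encoder_sizes=[64, 64],
#       x_y_encoder_sizes=[64, 64],
#       global_latent_net_sizes=[64, 128],
#       local_latent_net_sizes=[64, 128],
#       uncertainty_type='attentive_freeform',
#       data_uncertainty=True)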
def __init__(self,
             input_dim,
             output_dim,
             x_encoder_net_sizes=None,
             x_y_encoder_net_sizes=None,
             heteroskedastic_net_sizes=None,
             global_latent_net_sizes=None,
             local_latent_net_sizes=None,
             decoder_net_sizes=None,
             att_type='multihead',
             att_heads=8,
             model_type='fully_connected',
             activation=tf.nn.relu,
             output_activation=None,
             model_path=None,
             data_uncertainty=True,
             beta=1.,
             temperature=1.):
  """Initializes the generalized neural process regressor.

  D below denotes:
  - the context dataset C during the decoding phase,
  - the target dataset T during the encoding phase.

  Args:
    input_dim: (int) Dimensionality of covariates x.
    output_dim: (int) Dimensionality of labels y.
    x_encoder_net_sizes: (list of ints) Hidden layer sizes for the network
      featurizing x.
    x_y_encoder_net_sizes: (list of ints) Hidden layer sizes for the network
      featurizing D.
    heteroskedastic_net_sizes: (list of ints) Hidden layer sizes for the
      network that maps x to the heteroskedastic variance.
    global_latent_net_sizes: (list of ints) Hidden layer sizes for the
      network that maps D to the mean and variance of the predictive
      p(z | D).
    local_latent_net_sizes: (list of ints) Hidden layer sizes for the
      network that maps x_i, z, D to the mean and variance of the
      predictive p(z_i | z, x_i, D).
    decoder_net_sizes: (list of ints) Hidden layer sizes for the network
      that maps x_i, z, z_i, D to the mean and variance of the predictive
      p(y_i | z, z_i, x_i, D).
    att_type: (string) Attention type for freeform attention.
    att_heads: (int) Number of heads when att_type='multihead'.
    model_type: (string) One of 'fully_connected', 'cnp', 'acnp', 'acns',
      'np', 'anp'.
    activation: (callable) Non-linearity used for all neural networks.
    output_activation: (callable) Non-linearity for the predictive mean.
    model_path: (string) File path for the best early-stopped model.
    data_uncertainty: (boolean) True if data uncertainty is explicit.
    beta: (float) Scaling factor for the global KL loss.
    temperature: (float) Scale used by the cross-dataset attention.

  Raises:
    ValueError: If model_type is unrecognized.
  """
  if model_type not in [
      'np', 'anp', 'acns', 'fully_connected', 'cnp', 'acnp']:
    raise ValueError('Unrecognized model type: %s' % model_type)

  super(Regressor, self).__init__()
  self._input_dim = input_dim
  self._output_dim = output_dim
  self.model_type = model_type
  self._output_activation = output_activation
  self._data_uncertainty = data_uncertainty
  self.beta = tf.constant(beta)
  self.temperature = temperature

  self._global_latent_layer = None
  self._local_latent_layer = None
  self._decoder_layer = None
  self._dataset_encoding_layer = None
  self._x_encoder = None
  self._heteroskedastic_net = None
  self._homoskedastic_net = None

  contains_global = ['np', 'anp', 'acns', 'fully_connected']
  contains_local = ['acns', 'fully_connected']

  x_dim = input_dim
  if x_encoder_net_sizes is not None:
    self._x_encoder = utils.mlp_block(input_dim, x_encoder_net_sizes,
                                      activation)
    x_dim = x_encoder_net_sizes[-1]

  x_y_net = None
  self_dataset_attention = None
  if x_y_encoder_net_sizes is not None:
    x_y_net = utils.mlp_block(x_dim + output_dim, x_y_encoder_net_sizes,
                              activation)
    dataset_encoding_dim = x_y_encoder_net_sizes[-1]
  else:  # Use self-attention.
    dataset_encoding_dim = x_dim + output_dim
    self_dataset_attention = attention.AttentionLayer(
        att_type=att_type, num_heads=att_heads)
    self_dataset_attention.build([x_dim, x_dim])
  self._dataset_encoding_layer = layers.DatasetEncodingLayer(
      x_y_net, self_dataset_attention)
  self._cross_dataset_attention = attention.AttentionLayer(
      att_type=att_type, num_heads=att_heads, scale=self.temperature)
  self._cross_dataset_attention.build([x_dim, dataset_encoding_dim])

  if model_type in contains_global:
    global_latent_net = utils.mlp_block(dataset_encoding_dim,
                                        global_latent_net_sizes, activation)
    self._global_latent_layer = layers.GlobalLatentLayer(global_latent_net)
    global_latent_dim = global_latent_net_sizes[-1] // 2

  if model_type in contains_local:
    local_input_dim = global_latent_dim + dataset_encoding_dim
    local_latent_net = utils.mlp_block(local_input_dim,
                                       local_latent_net_sizes, activation)
    self._local_latent_layer = layers.LocalLatentLayer(local_latent_net)
    local_latent_dim = local_latent_net_sizes[-1] // 2

    separate_prior_net = (model_type != 'fully_connected')
    if separate_prior_net:
      local_latent_net = utils.mlp_block(global_latent_dim,
                                         local_latent_net_sizes, activation)
      self._prior_local_latent_layer = layers.LocalLatentLayer(
          local_latent_net)
    else:
      self._prior_local_latent_layer = self._local_latent_layer

  if decoder_net_sizes is not None:
    decoder_input_dim = x_dim
    if model_type == 'cnp' or model_type == 'acnp':  # Depends on C.
      decoder_input_dim += dataset_encoding_dim
    elif model_type == 'np':  # Depends on z.
      decoder_input_dim += global_latent_dim
    elif model_type == 'anp':  # Depends on z, C.
      decoder_input_dim += dataset_encoding_dim + global_latent_dim
    elif model_type == 'acns':  # Depends on z_i, C.
      decoder_input_dim += dataset_encoding_dim + local_latent_dim
    elif model_type == 'fully_connected':  # Depends on z, z_i, C.
      decoder_input_dim += (dataset_encoding_dim + global_latent_dim +
                            local_latent_dim)
    decoder_net = utils.mlp_block(decoder_input_dim, decoder_net_sizes,
                                  activation)
    self._decoder_layer = layers.DecoderLayer(decoder_net, model_type,
                                              output_activation)

  if data_uncertainty:
    if heteroskedastic_net_sizes is not None:
      self._heteroskedastic_net = utils.mlp_block(
          x_dim, heteroskedastic_net_sizes, activation)
    else:
      self._homoskedastic_net = layers.DataNoise()
      self._homoskedastic_net.build(None)

  if model_path:
    self.load_weights(model_path)
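# Example usage (an illustrative sketch, not part of the original module):
# an attentive neural process ('anp'), whose decoder conditions on both the
# global latent z and the dataset encoding of C. Sizes are arbitrary;
# decoder_net_sizes ends in 2 * output_dim so the decoder can emit a mean
# and a variance per label dimension. Local latent nets are omitted since
# 'anp' is not in contains_local.
#
#   anp = Regressor(
#       input_dim=1,
#       output_dim=1,
#       x_encoder_net_sizes=[64, 64],
#       x_y_encoder_net_sizes=[64, 64],
#       global_latent_net_sizes=[64, 128],
#       decoder_net_sizes=[64, 64, 2],
#       model_type='anp',
#       beta=1.,
#       temperature=1.)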