예제 #1
0
 def config(self):
     """Return this layer's own hyperparameters as a plain dict.

     Plain attributes (channel count, head settings, dropout rate and the
     attention-coefficient flag) are copied verbatim; the attention-kernel
     initializer, regularizer and constraint are converted with their
     matching Keras ``serialize`` helpers.

     NOTE(review): no decorator is visible in this excerpt; if callers read
     ``self.config`` as an attribute, a ``@property`` is expected above this
     def -- confirm against the enclosing class.
     """
     plain_attrs = (
         "channels",
         "attn_heads",
         "concat_heads",
         "dropout_rate",
         "return_attn_coef",
     )
     cfg = {attr: getattr(self, attr) for attr in plain_attrs}

     # Serialized objects come last, in the same order as the attributes.
     cfg["attn_kernel_initializer"] = initializers.serialize(
         self.attn_kernel_initializer
     )
     cfg["attn_kernel_regularizer"] = regularizers.serialize(
         self.attn_kernel_regularizer
     )
     cfg["attn_kernel_constraint"] = constraints.serialize(
         self.attn_kernel_constraint
     )
     return cfg
예제 #2
0
 def get_config(self):
   """Return the serializable configuration of this quantized simple RNN.

   The layer-specific entries (units, activation, bias flag, initializers,
   regularizers, constraints, quantizers and dropout rates) are collected
   first, in a fixed order. They are then layered over the parent RNN
   config after its 'cell' entry has been removed. On a key clash, the
   layer-specific value wins.
   """
   # Quantizer objects are serialized through constraints.serialize, the
   # same helper used for the weight constraints.
   entries = [
       ('units', self.units),
       ('activation', activations.serialize(self.activation)),
       ('use_bias', self.use_bias),
       ('kernel_initializer',
        initializers.serialize(self.kernel_initializer)),
       ('recurrent_initializer',
        initializers.serialize(self.recurrent_initializer)),
       ('bias_initializer',
        initializers.serialize(self.bias_initializer)),
       ('kernel_regularizer',
        regularizers.serialize(self.kernel_regularizer)),
       ('recurrent_regularizer',
        regularizers.serialize(self.recurrent_regularizer)),
       ('bias_regularizer',
        regularizers.serialize(self.bias_regularizer)),
       ('activity_regularizer',
        regularizers.serialize(self.activity_regularizer)),
       ('kernel_constraint',
        constraints.serialize(self.kernel_constraint)),
       ('recurrent_constraint',
        constraints.serialize(self.recurrent_constraint)),
       ('bias_constraint',
        constraints.serialize(self.bias_constraint)),
       ('kernel_quantizer',
        constraints.serialize(self.kernel_quantizer_internal)),
       ('recurrent_quantizer',
        constraints.serialize(self.recurrent_quantizer_internal)),
       ('bias_quantizer',
        constraints.serialize(self.bias_quantizer_internal)),
       ('state_quantizer',
        constraints.serialize(self.state_quantizer_internal)),
       ('dropout', self.dropout),
       ('recurrent_dropout', self.recurrent_dropout),
   ]
   layer_config = dict(entries)

   parent_config = super(QSimpleRNN, self).get_config()
   # The parent's 'cell' entry is dropped; raises KeyError if it is absent.
   parent_config.pop('cell')

   merged = dict(parent_config)
   merged.update(layer_config)
   return merged
예제 #3
0
 def get_config(self):
     """Return the serializable configuration of this embedding layer.

     Dimension, dropout and masking settings are stored as-is. The
     embeddings initializer, the two regularizers and the embeddings
     constraint are converted with the matching Keras ``serialize`` helper.
     The parent layer's config is used as the base, and the entries built
     here take precedence on a key clash.
     """
     own = {}
     own['input_dims'] = self.input_dims
     own['output_dims'] = self.output_dims
     own['dropout_rate'] = self.dropout_rate
     own['embeddings_initializer'] = initializers.serialize(
         self.embeddings_initializer)
     own['embeddings_regularizer'] = regularizers.serialize(
         self.embeddings_regularizer)
     own['activity_regularizer'] = regularizers.serialize(
         self.activity_regularizer)
     own['embeddings_constraint'] = constraints.serialize(
         self.embeddings_constraint)
     own['mask_zero'] = self.mask_zero

     # Copy so the dict returned by the parent is not mutated.
     merged = dict(super(MultiColumnEmbedding, self).get_config())
     merged.update(own)
     return merged
예제 #4
0
    def get_config(self) -> dict:
        """
        Build a key-value representation of the layer configuration.

        The layer's own settings (``num_sums``, ``logspace_accumulators``
        and the serialized accumulator initializer, regularizer and linear
        accumulator constraint) are merged over the parent layer's config.
        On a key clash, the layer's own value wins.

        Returns:
            A dict holding the configuration of the layer.
        """
        own_config = {
            "num_sums": self.num_sums,
            "accumulator_initializer": initializers.serialize(
                self.accumulator_initializer
            ),
            "logspace_accumulators": self.logspace_accumulators,
            "accumulator_regularizer": regularizers.serialize(
                self.accumulator_regularizer
            ),
            "linear_accumulator_constraint": constraints.serialize(
                self.linear_accumulator_constraint
            ),
        }
        merged = dict(super(DenseSum, self).get_config())
        merged.update(own_config)
        return merged
예제 #5
0
 def get_config(self):
     """Return the serializable configuration of this quantized batch norm.

     Normalization hyperparameters (axis, momentum, epsilon, center, scale,
     beta/gamma ranges) are stored as-is. Quantizers, initializers,
     regularizers and constraints are converted with the Keras
     ``serialize`` helpers. The result is the parent layer's config updated
     with these entries; entries built here win on a key clash.
     """
     config = {
         'axis': self.axis,
         'momentum': self.momentum,
         'epsilon': self.epsilon,
         'center': self.center,
         'scale': self.scale,
         # Every quantizer, including the inverse quantizer below, goes
         # through the same serializer.
         'beta_quantizer':
             constraints.serialize(self.beta_quantizer_internal),
         'gamma_quantizer':
             constraints.serialize(self.gamma_quantizer_internal),
         'mean_quantizer':
             constraints.serialize(self.mean_quantizer_internal),
         'variance_quantizer':
             constraints.serialize(self.variance_quantizer_internal),
         'beta_initializer':
             initializers.serialize(self.beta_initializer),
         'gamma_initializer':
             initializers.serialize(self.gamma_initializer),
         'moving_mean_initializer':
             initializers.serialize(self.moving_mean_initializer),
         'moving_variance_initializer':
             initializers.serialize(self.moving_variance_initializer),
         # A quantizer, not an initializer: serialize it like the others.
         'inverse_quantizer':
             constraints.serialize(self.inverse_quantizer_internal),
         'beta_regularizer':
             regularizers.serialize(self.beta_regularizer),
         'gamma_regularizer':
             regularizers.serialize(self.gamma_regularizer),
         'beta_constraint':
             constraints.serialize(self.beta_constraint),
         'gamma_constraint':
             constraints.serialize(self.gamma_constraint),
         'beta_range': self.beta_range,
         'gamma_range': self.gamma_range,
     }
     merged = dict(super(QBatchNormalization, self).get_config())
     merged.update(config)
     return merged