def __init__(self,
             network: Union[tf.keras.layers.Layer, tf.keras.Model],
             start_n_top: int,
             end_n_top: int,
             dropout_rate: float,
             span_labeling_activation: str = 'tanh',
             initializer: tf.keras.initializers.Initializer = 'glorot_uniform',
             **kwargs):
  """Initializes the span-labeling (question answering) head on `network`.

  Args:
    network: The transformer encoder backbone. Its `get_config()` is expected
      to expose the encoder's inner/feed-forward width under either
      `'inner_size'` (XLNet-style encoders) or `'intermediate_size'`
      (legacy BertEncoder naming).
    start_n_top: Beam size for the span start position.
    end_n_top: Beam size for the span end position.
    dropout_rate: Dropout rate applied in the span-labeling head.
    span_labeling_activation: Activation (name or callable accepted by Keras)
      for the span-labeling head. Note: this is an activation, not an
      initializer.
    initializer: Kernel initializer for the span-labeling head.
    **kwargs: Forwarded to the base class constructor.
  """
  super().__init__(**kwargs)
  # Stored so the constructor arguments can be round-tripped (get_config-style
  # serialization elsewhere in the class presumably reads this).
  self._config = {
      'network': network,
      'start_n_top': start_n_top,
      'end_n_top': end_n_top,
      'dropout_rate': dropout_rate,
      'span_labeling_activation': span_labeling_activation,
      'initializer': initializer,
  }
  network_config = network.get_config()
  try:
    # XLNet-style encoders expose the feed-forward width as 'inner_size'.
    input_width = network_config['inner_size']
    self._xlnet_base = True
  except KeyError:
    # BertEncoder uses 'intermediate_size' due to legacy naming.
    input_width = network_config['intermediate_size']
    self._xlnet_base = False
  self._network = network
  self._initializer = initializer
  self._start_n_top = start_n_top
  self._end_n_top = end_n_top
  self._dropout_rate = dropout_rate
  self._activation = span_labeling_activation
  self.span_labeling = networks.XLNetSpanLabeling(
      input_width=input_width,
      start_n_top=self._start_n_top,
      end_n_top=self._end_n_top,
      activation=self._activation,
      dropout_rate=self._dropout_rate,
      initializer=self._initializer)
def __init__(self,
             network: Union[tf.keras.layers.Layer, tf.keras.Model],
             start_n_top: int = 5,
             end_n_top: int = 5,
             dropout_rate: float = 0.1,
             span_labeling_activation: str = 'tanh',
             initializer: tf.keras.initializers.Initializer = 'glorot_uniform',
             **kwargs):
  """Initializes the span-labeling (question answering) head on `network`.

  Args:
    network: The transformer encoder backbone. Its `get_config()` is expected
      to expose the encoder's inner/feed-forward width under either
      `'inner_size'` (XLNet-style encoders) or `'intermediate_size'`
      (legacy BertEncoder naming).
    start_n_top: Beam size for the span start position.
    end_n_top: Beam size for the span end position.
    dropout_rate: Dropout rate applied in the span-labeling head.
    span_labeling_activation: Activation (name or callable accepted by Keras)
      for the span-labeling head. Note: this is an activation, not an
      initializer.
    initializer: Kernel initializer for the span-labeling head.
    **kwargs: Forwarded to the base class constructor.
  """
  super().__init__(**kwargs)
  # Stored so the constructor arguments can be round-tripped (get_config-style
  # serialization elsewhere in the class presumably reads this).
  self._config = {
      'network': network,
      'start_n_top': start_n_top,
      'end_n_top': end_n_top,
      'dropout_rate': dropout_rate,
      'span_labeling_activation': span_labeling_activation,
      'initializer': initializer,
  }
  network_config = network.get_config()
  try:
    # XLNet-style encoders expose the feed-forward width as 'inner_size'.
    input_width = network_config['inner_size']
    self._xlnet_base = True
  except KeyError:
    # BertEncoder uses 'intermediate_size' due to legacy naming.
    input_width = network_config['intermediate_size']
    self._xlnet_base = False
  self._network = network
  self._initializer = initializer
  self._start_n_top = start_n_top
  self._end_n_top = end_n_top
  self._dropout_rate = dropout_rate
  self._activation = span_labeling_activation
  self.span_labeling = networks.XLNetSpanLabeling(
      input_width=input_width,
      start_n_top=self._start_n_top,
      end_n_top=self._end_n_top,
      activation=self._activation,
      dropout_rate=self._dropout_rate,
      initializer=self._initializer)