def __init__(self,
             network,
             num_classes,
             initializer='glorot_uniform',
             output='logits',
             dropout_rate=0.1,
             **kwargs):
    """Builds a model with a span-labeling head and a classification head.

    Args:
      network: Callable object exposing an `inputs` attribute. Calling it on
        those inputs must return a pair that unpacks as
        `(sequence_output, cls_output)`.
      num_classes: Number of classes, forwarded to `networks.Classification`.
      initializer: Initializer forwarded to both heads.
      output: Output type forwarded to both heads.
      dropout_rate: Rate of the dropout layer applied to `cls_output` before
        the classification head.
      **kwargs: Forwarded unchanged to the parent constructor.
    """
    # NOTE(review): presumably this turns off attribute tracking so that the
    # `self` assignments made before super().__init__ are stored as plain
    # attributes -- confirm against the parent class.
    self._self_setattr_tracking = False
    self._config = {
        'network': network,
        'num_classes': num_classes,
        'initializer': initializer,
        'output': output,
        # Recorded so that code rebuilding the model from `_config` keeps a
        # non-default dropout rate instead of falling back to the default.
        'dropout_rate': dropout_rate,
    }

    # We want to use the inputs of the passed network as the inputs to this
    # Model. To do this, we need to keep a handle to the network inputs for use
    # when we construct the Model object at the end of init.
    inputs = network.inputs

    # Because we have a copy of inputs to create this Model object, we can
    # invoke the Network object with its own input tensors to start the Model.
    sequence_output, cls_output = network(inputs)

    # Classification branch: dropout on the cls output, then the classifier.
    cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)

    # The task heads are instance variables for ease of access to the
    # underlying task networks.
    self.classifier = networks.Classification(
        input_width=cls_output.shape[-1],
        num_classes=num_classes,
        initializer=initializer,
        output=output,
        name='classification')
    predictions = self.classifier(cls_output)

    # Span-labeling branch, fed by the sequence output.
    self.span_labeling = networks.SpanLabeling(
        input_width=sequence_output.shape[-1],
        initializer=initializer,
        output=output,
        name='span_labeling')
    start_logits, end_logits = self.span_labeling(sequence_output)

    # Use identity layers wrapped in lambdas to explicitly name the output
    # tensors. This allows us to use string-keyed dicts in Keras fit/predict/
    # evaluate calls.
    start_logits = tf.keras.layers.Lambda(
        tf.identity, name='start_positions')(start_logits)
    end_logits = tf.keras.layers.Lambda(
        tf.identity, name='end_positions')(end_logits)

    # Output order is part of the model's contract: start, end, class scores.
    logits = [start_logits, end_logits, predictions]

    super(BertUnifiedLabeler, self).__init__(
        inputs=inputs, outputs=logits, **kwargs)
def __init__(self,
             network,
             initializer='glorot_uniform',
             output='logits',
             **kwargs):
    """Builds a span-labeling model on top of `network`.

    Args:
      network: Callable object exposing an `inputs` attribute. Calling it on
        those inputs must return either a list/tuple whose first element is
        the sequence output, or a mapping with a `'sequence_output'` key. The
        sequence output may itself be a list/tuple, in which case its last
        element is used.
      initializer: Initializer forwarded to `networks.SpanLabeling`.
      output: Output type forwarded to `networks.SpanLabeling`.
      **kwargs: Forwarded unchanged to the parent constructor.
    """
    # We want to use the inputs of the passed network as the inputs to this
    # Model. To do this, we need to keep a handle to the network inputs for use
    # when we construct the Model object at the end of init.
    inputs = network.inputs

    # Because we have a copy of inputs to create this Model object, we can
    # invoke the Network object with its own input tensors to start the Model.
    outputs = network(inputs)

    # Positional outputs may arrive as a list or as a tuple; a tuple must not
    # fall through to the string-keyed lookup, which would raise TypeError.
    if isinstance(outputs, (list, tuple)):
        sequence_output = outputs[0]
    else:
        sequence_output = outputs['sequence_output']

    # The input network (typically a transformer model) may get outputs from
    # all layers. When this case happens, we retrieve the last layer output.
    if isinstance(sequence_output, (list, tuple)):
        sequence_output = sequence_output[-1]

    # Kept in a local for now; it is attached to `self` only after
    # super().__init__ has run (see below).
    span_labeling = networks.SpanLabeling(
        input_width=sequence_output.shape[-1],
        initializer=initializer,
        output=output,
        name='span_labeling')
    start_logits, end_logits = span_labeling(sequence_output)

    # Use identity layers wrapped in lambdas to explicitly name the output
    # tensors. This allows us to use string-keyed dicts in Keras fit/predict/
    # evaluate calls.
    start_logits = tf.keras.layers.Lambda(
        tf.identity, name='start_positions')(start_logits)
    end_logits = tf.keras.layers.Lambda(
        tf.identity, name='end_positions')(end_logits)

    logits = [start_logits, end_logits]

    # b/164516224
    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a
    # model created using the Functional API. Once super().__init__ is called,
    # we can assign attributes to `self` - note that all `self` assignments are
    # below this line.
    super(BertSpanLabeler, self).__init__(
        inputs=inputs, outputs=logits, **kwargs)
    self._network = network
    config_dict = {
        'network': network,
        'initializer': initializer,
        'output': output,
    }
    # We are storing the config dict as a namedtuple here to ensure checkpoint
    # compatibility with an earlier version of this model which did not track
    # the config dict attribute. TF does not track immutable attrs which
    # do not contain Trackables, so by creating a config namedtuple instead of
    # a dict we avoid tracking it.
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)

    # This is an instance variable for ease of access to the underlying task
    # network.
    self.span_labeling = span_labeling