예제 #1
0
    def __init__(self,
                 network,
                 num_classes,
                 initializer='glorot_uniform',
                 output='logits',
                 dropout_rate=0.1,
                 **kwargs):
        """Builds a model with span-labeling and classification heads.

        The model reuses `network.inputs` as its own inputs. It exposes three
        outputs, in this order:
          1. start-position logits (output layer named 'start_positions');
          2. end-position logits (output layer named 'end_positions');
          3. the classification head's output.

        Args:
          network: The encoder network. Calling it on its own `inputs` must
            return a pair `(sequence_output, cls_output)`. If `sequence_output`
            is a list (one entry per layer), the last entry is used.
          num_classes: Number of classes for the classification head.
          initializer: Initializer passed to both the classification head and
            the span-labeling head.
          output: Output mode passed to both heads (presumably 'logits' or
            'predictions' -- confirm against `networks.Classification` and
            `networks.SpanLabeling`).
          dropout_rate: Dropout rate applied to `cls_output` before it is fed
            to the classification head.
          **kwargs: Forwarded to the Keras functional `Model` constructor.
        """
        # Disable automatic attribute tracking: the attributes below are
        # assigned before `super().__init__()` runs.
        self._self_setattr_tracking = False
        # Constructor arguments (other than **kwargs), presumably returned by
        # `get_config()` elsewhere in the class. `dropout_rate` must be
        # recorded too, otherwise a model rebuilt from this config would use
        # the default rate instead of the one actually used here.
        self._config = {
            'network': network,
            'num_classes': num_classes,
            'initializer': initializer,
            'output': output,
            'dropout_rate': dropout_rate,
        }
        # We want to use the inputs of the passed network as the inputs to this
        # Model. To do this, we need to keep a handle to the network inputs for
        # use when we construct the Model object at the end of init.
        inputs = network.inputs

        # Because we have a copy of inputs to create this Model object, we can
        # invoke the Network object with its own input tensors to start the
        # Model.
        sequence_output, cls_output = network(inputs)

        # The input network may return outputs from all layers. When this
        # happens, use the last layer's output; a plain tensor is untouched.
        if isinstance(sequence_output, list):
            sequence_output = sequence_output[-1]

        # Classification head on the dropout-regularized CLS output.
        cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
        self.classifier = networks.Classification(
            input_width=cls_output.shape[-1],
            num_classes=num_classes,
            initializer=initializer,
            output=output,
            name='classification')

        predictions = self.classifier(cls_output)

        # Span-labeling head on the sequence output. Kept as an instance
        # variable for ease of access to the underlying task network.
        self.span_labeling = networks.SpanLabeling(
            input_width=sequence_output.shape[-1],
            initializer=initializer,
            output=output,
            name='span_labeling')
        start_logits, end_logits = self.span_labeling(sequence_output)

        # Use identity layers wrapped in lambdas to explicitly name the output
        # tensors. This allows us to use string-keyed dicts in Keras
        # fit/predict/evaluate calls.
        start_logits = tf.keras.layers.Lambda(
            tf.identity, name='start_positions')(start_logits)
        end_logits = tf.keras.layers.Lambda(
            tf.identity, name='end_positions')(end_logits)

        logits = [start_logits, end_logits, predictions]

        super(BertUnifiedLabeler, self).__init__(inputs=inputs,
                                                 outputs=logits,
                                                 **kwargs)
예제 #2
0
    def __init__(self,
                 network,
                 initializer='glorot_uniform',
                 output='logits',
                 **kwargs):
        """Builds a span-labeling model on top of `network`.

        The model reuses `network.inputs` as its own inputs. It exposes two
        outputs, in this order:
          1. start-position logits (output layer named 'start_positions');
          2. end-position logits (output layer named 'end_positions').

        Args:
          network: The encoder network. Calling it on its own `inputs` must
            return either a list/tuple whose first element is the sequence
            output, or a mapping with a 'sequence_output' entry. If the
            sequence output is itself a list (one entry per layer), the last
            entry is used.
          initializer: Initializer passed to the span-labeling head.
          output: Output mode passed to the span-labeling head (presumably
            'logits' or 'predictions' -- confirm against
            `networks.SpanLabeling`).
          **kwargs: Forwarded to the Keras functional `Model` constructor.
        """

        # We want to use the inputs of the passed network as the inputs to this
        # Model. To do this, we need to keep a handle to the network inputs for
        # use when we construct the Model object at the end of init.
        inputs = network.inputs

        # Because we have a copy of inputs to create this Model object, we can
        # invoke the Network object with its own input tensors to start the
        # Model.
        outputs = network(inputs)
        # Sequence-like outputs (list or tuple) carry the sequence output
        # first; anything else is treated as a mapping keyed by output name.
        # Tuples used to fall through to the mapping branch and raise
        # TypeError.
        if isinstance(outputs, (list, tuple)):
            sequence_output = outputs[0]
        else:
            sequence_output = outputs['sequence_output']

        # The input network (typically a transformer model) may get outputs from
        # all layers. When this case happens, we retrieve the last layer output.
        if isinstance(sequence_output, list):
            sequence_output = sequence_output[-1]

        # Span-labeling head. It is assigned to `self.span_labeling` only after
        # `super().__init__` (see below).
        span_labeling = networks.SpanLabeling(
            input_width=sequence_output.shape[-1],
            initializer=initializer,
            output=output,
            name='span_labeling')
        start_logits, end_logits = span_labeling(sequence_output)

        # Use identity layers wrapped in lambdas to explicitly name the output
        # tensors. This allows us to use string-keyed dicts in Keras
        # fit/predict/evaluate calls.
        start_logits = tf.keras.layers.Lambda(
            tf.identity, name='start_positions')(start_logits)
        end_logits = tf.keras.layers.Lambda(
            tf.identity, name='end_positions')(end_logits)

        logits = [start_logits, end_logits]

        # b/164516224
        # Once we've created the network using the Functional API, we call
        # super().__init__ as though we were invoking the Functional API Model
        # constructor, resulting in this object having all the properties of a
        # model created using the Functional API. Once super().__init__ is
        # called, we can assign attributes to `self` - note that all `self`
        # assignments are below this line.
        super(BertSpanLabeler, self).__init__(inputs=inputs,
                                              outputs=logits,
                                              **kwargs)
        self._network = network
        config_dict = {
            'network': network,
            'initializer': initializer,
            'output': output,
        }
        # We are storing the config dict as a namedtuple here to ensure
        # checkpoint compatibility with an earlier version of this model which
        # did not track the config dict attribute. TF does not track immutable
        # attrs which do not contain Trackables, so by creating a config
        # namedtuple instead of a dict we avoid tracking it.
        config_cls = collections.namedtuple('Config', config_dict.keys())
        self._config = config_cls(**config_dict)
        self.span_labeling = span_labeling