Example #1
def __init__(self, lr=1e-4, dropout_rate=0.2, units=300, beam_width=12, vocab_size=9000):
    super().__init__()
    # Frozen pretrained embedding matrix loaded from disk.
    self.embedding = tf.Variable(np.load('../data/embedding.npy'),
                                 dtype=tf.float32,
                                 name='pretrained_embedding',
                                 trainable=False)
    self.encoder = Encoder(units=units, dropout_rate=dropout_rate)
    # Attention-wrapped GRU cell, shared by training and beam-search decoding.
    self.attention_mechanism = BahdanauAttention(units=units)
    self.decoder_cell = AttentionWrapper(
        GRUCell(units),
        self.attention_mechanism,
        attention_layer_size=units)
    # Output projection over the vocabulary, built from the embedding matrix.
    self.projected_layer = ProjectedLayer(self.embedding)
    # Teacher-forcing sampler for training-time decoding.
    self.sampler = tfa.seq2seq.sampler.TrainingSampler()
    self.decoder = BasicDecoder(
        self.decoder_cell,
        self.sampler,
        output_layer=self.projected_layer)
    # Beam-search decoder for inference, embedding token ids on the fly.
    self.beam_search = BeamSearchDecoder(
        self.decoder_cell,
        beam_width=beam_width,
        embedding_fn=lambda x: tf.nn.embedding_lookup(self.embedding, x),
        output_layer=self.projected_layer)
    self.vocab_size = vocab_size
    self.optimizer = Adam(lr)
    self.accuracy = tf.keras.metrics.Accuracy()
    self.mean = tf.keras.metrics.Mean()
    self.decay_lr = tf.optimizers.schedules.ExponentialDecay(lr, 1000, 0.95)
    self.logger = logging.getLogger('tensorflow')
    self.logger.setLevel(logging.INFO)
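
A minimal usage sketch (not part of the original example) of how these pieces are typically wired per batch with tfa.seq2seq: the attention mechanism is pointed at the encoder output, the wrapped cell produces an initial state, and the TrainingSampler decoder runs teacher forcing. The names `enc_output`, `source_lengths`, `target_ids`, and `target_lengths` are assumed inputs, not from the original code.

# Sketch of a training-step body; assumes it runs as a method of the model
# above, with enc_output/source_lengths coming from self.encoder and
# target_ids/target_lengths as hypothetical batch inputs.
self.attention_mechanism.setup_memory(
    enc_output, memory_sequence_length=source_lengths)
initial_state = self.decoder_cell.get_initial_state(
    batch_size=tf.shape(target_ids)[0], dtype=tf.float32)
decoder_inputs = tf.nn.embedding_lookup(self.embedding, target_ids)
outputs, _, _ = self.decoder(
    decoder_inputs,
    initial_state=initial_state,
    sequence_length=target_lengths)
logits = outputs.rnn_output  # [batch, time, vocab], fed to the loss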
Example #2
    def __init__(
            self,
            num_classes,
            cell_type='rnn',
            state_size=256,
            embedding_size=64,
            beam_width=1,
            num_layers=1,
            attention=None,
            tied_embeddings=None,
            is_timeseries=False,
            max_sequence_length=0,
            use_bias=True,
            weights_initializer='glorot_uniform',
            bias_initializer='zeros',
            weights_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            reduce_input='sum',
            **kwargs
    ):
        super().__init__()
        logger.debug(' {}'.format(self.name))

        self.cell_type = cell_type
        self.state_size = state_size
        self.embedding_size = embedding_size
        self.beam_width = beam_width
        self.num_layers = num_layers
        self.attention = attention
        self.attention_mechanism = None
        self.tied_embeddings = tied_embeddings
        self.is_timeseries = is_timeseries
        self.num_classes = num_classes
        self.max_sequence_length = max_sequence_length

        self.reduce_input = reduce_input if reduce_input else 'sum'
        self.reduce_sequence = SequenceReducer(reduce_mode=self.reduce_input)

        if is_timeseries:
            self.vocab_size = 1
        else:
            self.vocab_size = self.num_classes

        # GO takes the extra id just past the vocabulary; END reuses id 0.
        self.GO_SYMBOL = self.vocab_size
        self.END_SYMBOL = 0

        logger.debug('  project input Dense')
        self.project = Dense(
            state_size,
            use_bias=use_bias,
            kernel_initializer=weights_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=weights_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer
        )

        logger.debug('  Embedding')
        self.decoder_embedding = Embedding(
            input_dim=self.num_classes + 1,  # account for GO_SYMBOL
            output_dim=embedding_size,
            embeddings_initializer=weights_initializer,
            embeddings_regularizer=weights_regularizer,
            activity_regularizer=activity_regularizer
        )
        logger.debug('  project output Dense')
        self.dense_layer = Dense(
            num_classes,
            use_bias=use_bias,
            kernel_initializer=weights_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=weights_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer
        )
        # Build the decoder RNN: num_layers cells of the configured type.
        rnn_cell = get_from_registry(cell_type, rnn_layers_registry)
        rnn_cells = [rnn_cell(state_size) for _ in range(num_layers)]
        self.decoder_rnncell = StackedRNNCells(rnn_cells)
        logger.debug('  {}'.format(self.decoder_rnncell))

        # Sampler
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()

        logger.debug('setting up attention for {}'.format(attention))
        if attention is not None:
            if attention == 'luong':
                self.attention_mechanism = LuongAttention(units=state_size)
            elif attention == 'bahdanau':
                self.attention_mechanism = BahdanauAttention(units=state_size)
            logger.debug('  {}'.format(self.attention_mechanism))
            self.decoder_rnncell = AttentionWrapper(
                self.decoder_rnncell,
                [self.attention_mechanism] * num_layers,
                attention_layer_size=[state_size] * num_layers
            )
            logger.debug('  {}'.format(self.decoder_rnncell))
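
For context, a hedged inference sketch (not part of the original snippet) showing how the GO/END symbols and the attention-wrapped cell built above are typically driven at prediction time with a greedy sampler. The names `encoder_output`, `encoder_lengths`, and `batch_size` are assumed inputs, and `max_sequence_length` is assumed to hold a positive bound here.

# Sketch of greedy decoding; assumes it runs as a method of the class above.
if self.attention_mechanism is not None:
    self.attention_mechanism.setup_memory(
        encoder_output, memory_sequence_length=encoder_lengths)
greedy_sampler = tfa.seq2seq.GreedyEmbeddingSampler(self.decoder_embedding)
greedy_decoder = tfa.seq2seq.BasicDecoder(
    self.decoder_rnncell, greedy_sampler, output_layer=self.dense_layer)
initial_state = self.decoder_rnncell.get_initial_state(
    batch_size=batch_size, dtype=tf.float32)
outputs, _, _ = tfa.seq2seq.dynamic_decode(
    greedy_decoder,
    maximum_iterations=self.max_sequence_length,
    decoder_init_input=None,
    decoder_init_kwargs=dict(
        start_tokens=tf.fill([batch_size], self.GO_SYMBOL),
        end_token=self.END_SYMBOL,
        initial_state=initial_state))
predictions = outputs.sample_id  # greedy token ids per step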