Example #1
    def build_input(
            self,
            regularizer,
            dropout_rate,
            is_training=False,
            **kwargs
    ):
        placeholder = self._get_input_placeholder()

        feature_representation = fc_layer(
            tf.expand_dims(tf.cast(placeholder, tf.float32), 1),
            1,
            1,
            activation=None,
            norm=self.norm,
            dropout=self.dropout,
            dropout_rate=dropout_rate,
            regularizer=regularizer,
            initializer='ones'
        )

        logger.debug('  feature_representation: {0}'.format(
            feature_representation))

        feature_representation = {'type': self.name,
                                  'representation': feature_representation,
                                  'size': 1,
                                  'placeholder': placeholder}

        return feature_representation
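
For context, here is a minimal, hypothetical sketch of how a method like build_input might be called. The NumericalInputFeature class, its constructor argument, and the regularizer choice are illustrative assumptions, not part of the original source:

import tensorflow as tf

# hypothetical feature class that defines build_input() as shown above
feature = NumericalInputFeature({'name': 'price'})

representation = feature.build_input(
    regularizer=tf.contrib.layers.l2_regularizer(scale=1e-4),
    dropout_rate=0.2,
    is_training=True
)

# build_input returns a dict with the tensor and its metadata
assert representation['size'] == 1
assert representation['type'] == feature.name
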
Example #2
def recurrent_decoder(encoder_outputs,
                      targets,
                      max_sequence_length,
                      vocab_size,
                      cell_type='rnn',
                      state_size=256,
                      embedding_size=50,
                      num_layers=1,
                      attention_mechanism=None,
                      beam_width=1,
                      projection=True,
                      tied_target_embeddings=True,
                      embeddings=None,
                      initializer=None,
                      regularizer=None,
                      is_timeseries=False):
    with tf.variable_scope('rnn_decoder',
                           reuse=tf.AUTO_REUSE,
                           regularizer=regularizer):

        # ================ Setup ================
        if beam_width > 1 and is_timeseries:
            raise ValueError(
                'Invalid beam_width: {}. Beam search decoding is not '
                'supported for time series'.format(beam_width))

        GO_SYMBOL = vocab_size
        END_SYMBOL = 0
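        # GO uses the extra id one past the vocabulary (the embedding matrix
        # has vocab_size + 1 rows); EOS reuses id 0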
        batch_size = tf.shape(encoder_outputs)[0]

        # ================ Projection ================
        # Project the encoder outputs to the size of the decoder state
        encoder_outputs_size = encoder_outputs.shape[-1]
        if projection and encoder_outputs_size != state_size:
            with tf.variable_scope('projection'):
                encoder_output_rank = len(encoder_outputs.shape)
                if encoder_output_rank > 2:
                    sequence_length = tf.shape(encoder_outputs)[1]
                    encoder_outputs = tf.reshape(encoder_outputs,
                                                 [-1, encoder_outputs_size])
                    encoder_outputs = fc_layer(encoder_outputs,
                                               encoder_outputs.shape[-1],
                                               state_size,
                                               activation=None,
                                               initializer=initializer)
                    encoder_outputs = tf.reshape(
                        encoder_outputs, [-1, sequence_length, state_size])
                else:
                    encoder_outputs = fc_layer(encoder_outputs,
                                               encoder_outputs.shape[-1],
                                               state_size,
                                               activation=None,
                                               initializer=initializer)

        # ================ Targets sequence ================
        # Calculate the length of inputs and the batch size
        with tf.variable_scope('sequence'):
            targets_sequence_length = sequence_length_2D(targets)
            start_tokens = tf.tile([GO_SYMBOL], [batch_size])
            end_tokens = tf.tile([END_SYMBOL], [batch_size])
            if is_timeseries:
                start_tokens = tf.cast(start_tokens, tf.float32)
                end_tokens = tf.cast(end_tokens, tf.float32)
            targets_with_go_and_eos = tf.concat([
                tf.expand_dims(start_tokens, 1), targets,
                tf.expand_dims(end_tokens, 1)
            ], 1)
            logging.debug('  targets_with_go_and_eos: {0}'.format(
                targets_with_go_and_eos))
            # the EOS symbol is 0, so it does not increase the real length
            # of the sequence
            targets_sequence_length_with_eos = targets_sequence_length + 1

        # ================ Embeddings ================
        if is_timeseries:
            targets_embedded = tf.expand_dims(targets_with_go_and_eos, -1)
            targets_embeddings = None
        else:
            with tf.variable_scope('embedding'):
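                # with tied target embeddings, the embedding size and the
                # decoder state size must match, so one is forced to the other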
                if embeddings is not None:
                    embedding_size = embeddings.shape.as_list()[-1]
                    if tied_target_embeddings:
                        state_size = embedding_size
                elif tied_target_embeddings:
                    embedding_size = state_size

                if embeddings is not None:
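                    # append a trainable row for the GO symbol to the
                    # provided embedding matrix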
                    embedding_go = tf.get_variable(
                        'embedding_GO',
                        initializer=tf.random_uniform([1, embedding_size],
                                                      -1.0, 1.0))
                    targets_embeddings = tf.concat([embeddings, embedding_go],
                                                   axis=0)
                else:
                    initializer_obj = get_initializer(initializer)
                    targets_embeddings = tf.get_variable(
                        'embeddings',
                        initializer=initializer_obj(
                            [vocab_size + 1, embedding_size]),
                        regularizer=regularizer)
                logging.debug(
                    '  targets_embeddings: {0}'.format(targets_embeddings))

                targets_embedded = tf.nn.embedding_lookup(
                    targets_embeddings,
                    targets_with_go_and_eos,
                    name='decoder_input_embeddings')
        logging.debug('  targets_embedded: {0}'.format(targets_embedded))

        # ================ Class prediction ================
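        # when embeddings are tied, the transposed target embedding matrix
        # doubles as the output projection weights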
        if tied_target_embeddings:
            class_weights = tf.transpose(targets_embeddings)
        else:
            initializer_obj = get_initializer(initializer)
            class_weights = tf.get_variable('class_weights',
                                            initializer=initializer_obj(
                                                [state_size, vocab_size + 1]),
                                            regularizer=regularizer)
        logging.debug('  class_weights: {0}'.format(class_weights))
        class_biases = tf.get_variable('class_biases', [vocab_size + 1])
        logging.debug('  class_biases: {0}'.format(class_biases))
        projection_layer = Projection(class_weights, class_biases)

        # ================ RNN ================
        initial_state = encoder_outputs
        with tf.variable_scope('rnn_cells') as vs:
            # Cell
            cell_fun = get_cell_fun(cell_type)

            if num_layers == 1:
                cell = cell_fun(state_size)
                if cell_type.startswith('lstm'):
                    initial_state = LSTMStateTuple(c=initial_state,
                                                   h=initial_state)
            elif num_layers > 1:
                cell = MultiRNNCell(
                    [cell_fun(state_size) for _ in range(num_layers)],
                    state_is_tuple=True)
                if cell_type.startswith('lstm'):
                    initial_state = LSTMStateTuple(c=initial_state,
                                                   h=initial_state)
                initial_state = tuple([initial_state] * num_layers)
            else:
                raise ValueError(
                    'num_layers in recurrent decoder: {}. '
                    'Number of layers in a recurrent decoder cannot be <= 0'.
                    format(num_layers))

            # Attention
            if attention_mechanism is not None:
                if attention_mechanism == 'bahdanau':
                    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                        num_units=state_size,
                        memory=encoder_outputs,
                        memory_sequence_length=sequence_length_3D(
                            encoder_outputs))
                elif attention_mechanism == 'luong':
                    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                        num_units=state_size,
                        memory=encoder_outputs,
                        memory_sequence_length=sequence_length_3D(
                            encoder_outputs))
                else:
                    raise ValueError(
                        'Attention mechanism {} not supported'.format(
                            attention_mechanism))
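                # wrap the cell with attention and initialize the decoder
                # state from the last encoder output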
                cell = tf.contrib.seq2seq.AttentionWrapper(
                    cell, attention_mechanism, attention_layer_size=state_size)
                initial_state = cell.zero_state(dtype=tf.float32,
                                                batch_size=batch_size)
                initial_state = initial_state.clone(
                    cell_state=reduce_sequence(encoder_outputs, 'last'))

            for v in tf.global_variables():
                if v.name.startswith(vs.name):
                    logging.debug('  {}: {}'.format(v.name, v))

        # ================ Decoding ================
        def decode(initial_state,
                   cell,
                   helper,
                   beam_width=1,
                   projection_layer=None):
            # The decoder itself
            if beam_width > 1:
                # Tile inputs for beam search decoder
                beam_initial_state = tf.contrib.seq2seq.tile_batch(
                    initial_state, beam_width)
                decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                    cell=cell,
                    embedding=targets_embeddings,
                    start_tokens=start_tokens,
                    end_token=END_SYMBOL,
                    initial_state=beam_initial_state,
                    beam_width=beam_width,
                    output_layer=projection_layer)
            else:
                decoder = BasicDecoder(cell=cell,
                                       helper=helper,
                                       initial_state=initial_state,
                                       output_layer=projection_layer)

            # The decoding operation
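            # impute_finished is not supported by the beam search decoder,
            # so it is only enabled for greedy / training decoding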
            outputs = tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder,
                output_time_major=False,
                impute_finished=(beam_width <= 1),
                maximum_iterations=max_sequence_length)

            return outputs

        # ================ Decoding helpers ================
        if is_timeseries:
            train_helper = TimeseriesTrainingHelper(
                inputs=targets_embedded,
                sequence_length=targets_sequence_length_with_eos)
            (final_outputs_pred, final_state_pred,
             final_sequence_lengths_pred) = decode(
                initial_state,
                cell,
                train_helper,
                projection_layer=projection_layer)
            eval_logits = final_outputs_pred.rnn_output
            train_logits = final_outputs_pred.projection_input
            predictions_sequence = tf.reshape(eval_logits, [batch_size, -1])
            predictions_sequence_length_with_eos = final_sequence_lengths_pred

        else:
            train_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=targets_embedded,
                sequence_length=targets_sequence_length_with_eos)
            (final_outputs_train, final_state_train,
             final_sequence_lengths_train) = decode(
                initial_state,
                cell,
                train_helper,
                projection_layer=projection_layer)
            eval_logits = final_outputs_train.rnn_output
            train_logits = final_outputs_train.projection_input
            # train_predictions = final_outputs_train.sample_id

            pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                embedding=targets_embeddings,
                start_tokens=start_tokens,
                end_token=END_SYMBOL)
            (final_outputs_pred, final_state_pred,
             final_sequence_lengths_pred) = decode(
                initial_state,
                cell,
                pred_helper,
                beam_width,
                projection_layer=projection_layer)

            if beam_width > 1:
                beam_output = final_outputs_pred.beam_search_decoder_output
                # keep only the best beam;
                # final_outputs_pred.predicted_ids[:, :, 0] would work too,
                # but it contains -1s for padding
                predictions_sequence = beam_output.predicted_ids[:, :, 0]
                predictions_sequence_scores = beam_output.scores[:, :, 0]
                predictions_sequence_length_with_eos = \
                    final_sequence_lengths_pred[:, 0]
            else:
                predictions_sequence = final_outputs_pred.sample_id
                predictions_sequence_scores = final_outputs_pred.rnn_output
                predictions_sequence_length_with_eos = final_sequence_lengths_pred

    logging.debug('  train_logits: {0}'.format(train_logits))
    logging.debug('  eval_logits: {0}'.format(eval_logits))
    logging.debug('  predictions_sequence: {0}'.format(predictions_sequence))
    logging.debug('  predictions_sequence_scores: {0}'.format(
        predictions_sequence_scores))

    return (predictions_sequence, predictions_sequence_scores,
            predictions_sequence_length_with_eos,
            targets_sequence_length_with_eos, eval_logits, train_logits,
            class_weights, class_biases)
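
A rough usage sketch, assuming TensorFlow 1.x graph mode and the helpers imported by the surrounding module; the shapes and hyperparameters below are illustrative only:

import tensorflow as tf

# [batch, time, hidden] encoder states and [batch, time] target token ids
encoder_outputs = tf.placeholder(tf.float32, [None, 20, 256])
targets = tf.placeholder(tf.int32, [None, 15])

(predictions_sequence, predictions_sequence_scores,
 predictions_sequence_length_with_eos, targets_sequence_length_with_eos,
 eval_logits, train_logits, class_weights,
 class_biases) = recurrent_decoder(
    encoder_outputs,
    targets,
    max_sequence_length=20,
    vocab_size=10000,
    state_size=256,
    attention_mechanism='bahdanau',
    beam_width=1
)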