def setup_network(self):
        # Setup character embedding
        embedded_encoder_input, embedded_decoder_input, embed_func = self.setup_character_embedding()

        # Output projection
        with tf.variable_scope('alphabet_projection') as scope:
            self.projection_W, self.projection_b = intialize_projections(
                # 4 * num_units because we use a bidirectional encoder + attention
                input_size=4 * self.config.num_units,
                output_size=self.config.alphabet_size,
                scope=scope)

            # Define alphabet projection function
            def project_func(output):
                return projection(output,
                                  W=self.projection_W,
                                  b=self.projection_b)

        # Encoder
        with tf.variable_scope('encoder') as scope:
            enc_outputs, enc_final_state = self.encoder.encode(
                inputs=embedded_encoder_input,
                seq_lengths=self.encoder_sequence_length,
                # enc_word_indices=self.enc_word_indices,
                # word_seq_lengths=self.word_seq_lengths,
                # max_words=self.config.max_words,
                scope=scope)

        # Set decoder initial state and encoder outputs based on the binary
        # mode input value
        # - If `self.is_lm_mode=0` Use the passed initial state from encoder
        # - If `self.is_lm_mode=1` Use the zero vector
        self.enc_outputs, self.enc_final_state = select_decoder_inputs(
            is_lm_mode=self.is_lm_mode,
            enc_outputs=enc_outputs,
            initial_state=enc_final_state,
        )
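
        # Hypothetical sketch (not part of the original code) of the per-tensor
        # selection `select_decoder_inputs` performs, based on the comment
        # above: keep the encoder state where `is_lm_mode == 0` and replace it
        # with zeros where `is_lm_mode == 1`. The real helper presumably also
        # handles nested `LSTMStateTuple` states and the encoder outputs; this
        # function is illustrative only and is never called.
        def _select_state_sketch(is_lm_mode, state_tensor):
            keep = tf.expand_dims(tf.to_float(1 - is_lm_mode), axis=1)  # [B, 1]
            return state_tensor * keep  # zeroed out for LM-mode examples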

        # Pack state to tensor
        self.enc_final_state_tensor = pack_state_tuple(self.enc_final_state)

        # Initialize decoder attention function using encoder outputs
        self.decoder.initialize_attention_func(
            input_size=embedded_decoder_input.get_shape().as_list()[-1],
            attention_states=self.enc_outputs)

        # Define initial attention tensor
        self.initial_attention = self.decoder.attention_func(
            self.enc_final_state)

        # Define decoder
        with tf.variable_scope('decoder'):
            dec_outputs, dec_final_state = self.decoder.decode(
                inputs=embedded_decoder_input,
                initial_state=self.enc_final_state,
                seq_length=self.decoder_sequence_length,
                embed_func=embed_func,
                project_func=project_func,
            )

            # Project output to alphabet size and reshape
            dec_outputs = tf.reshape(dec_outputs,
                                     [-1, 4 * self.config.num_units])
            dec_outputs = projection(dec_outputs,
                                     W=self.projection_W,
                                     b=self.projection_b)
            dec_outputs = tf.reshape(dec_outputs, [
                -1, self.config.max_dec_seq_length + 1,
                self.config.alphabet_size
            ])

        if self.prediction_mode:
            dec_outputs = self.decoder_logits

        # Define loss
        self.setup_losses(dec_outputs=dec_outputs,
                          target_chars=self.target_chars,
                          decoder_sequence_length=self.decoder_sequence_length)

        if self.prediction_mode:
            # Look up inputs
            decoder_inputs_embedded = tf.nn.embedding_lookup(
                self.embedding_matrix,
                self.decoder_inputs,
                name='decoder_input')
            is_lm_mode_tensor = tf.to_float(
                tf.expand_dims(self.is_lm_mode, axis=1))
            decoder_inputs = tf.concat(
                [decoder_inputs_embedded, is_lm_mode_tensor], axis=1)

            # Unpack state
            initial_state = unpack_state_tensor(self.decoder_state)
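
            # Hypothetical sketch (not from the original helpers) of the
            # `pack_state_tuple` / `unpack_state_tensor` round trip for a
            # single-layer LSTM state; the real helpers presumably also handle
            # multi-layer (nested) state tuples. Illustrative only, never called.
            def _pack_state_sketch(state):
                # LSTMStateTuple(c, h) -> single [B, 2 * num_units] tensor
                return tf.concat([state.c, state.h], axis=1)

            def _unpack_state_sketch(state_tensor):
                # single [B, 2 * num_units] tensor -> LSTMStateTuple(c, h)
                c, h = tf.split(state_tensor, num_or_size_splits=2, axis=1)
                return tf.contrib.rnn.LSTMStateTuple(c, h)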

            with tf.variable_scope('decoder', reuse=True):
                decoder_output, decoder_final_state, self.decoder_new_attention = self.decoder.predict(
                    inputs=decoder_inputs,
                    initial_state=initial_state,
                    attention_states=self.decoder_attention)

            # Project output to alphabet size
            self.decoder_output = projection(decoder_output,
                                             W=self.projection_W,
                                             b=self.projection_b,
                                             name='decoder_output')

            # Compute decayed logits
            self.decoder_probs_decayed = compute_decayed_probs(
                logits=self.decoder_output,
                decay_parameter_ph=self.probs_decay_parameter)
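
            # Hypothetical sketch (not the original helper) of what
            # `compute_decayed_probs` appears to compute, mirroring the explicit
            # version in the language-model `setup_network` further below:
            # sharpen/flatten the softmax distribution with a decay exponent and
            # renormalize. Illustrative only, never called.
            def _compute_decayed_probs_sketch(logits, decay_parameter_ph):
                probs = tf.nn.softmax(logits)
                decayed = tf.pow(tf.cast(probs, tf.float64), decay_parameter_ph)
                return decayed / tf.reduce_sum(decayed, axis=1, keep_dims=True)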

            # Pack state to tensor
            self.decoder_final_state = pack_state_tuple(
                decoder_final_state, name='decoder_final_state')
Example #2
    def predict(self, session, lm_predict_func=None, **kwargs):
        assert self.prediction_mode

        def decode_func(inputs, state):
            output, probs, state = session.run(
                fetches=[
                    self.decoder_output, self.decoder_probs,
                    self.decoder_final_state
                ],
                feed_dict={
                    self.decoder_inputs: inputs,
                    self.decoder_state: state
                })

            if lm_predict_func is not None:
                # NOTE: the language-model predictions are fetched here but are
                # not yet combined with the decoder output in this version.
                lm_output, lm_probs, lm_state = lm_predict_func(inputs, state)

            return {'output': output, 'probs': probs, 'state': state}

        def loss_func(logits, targets, input_length):
            return session.run(
                fetches=[self.mean_loss_batch, self.mean_prob_x_batch],
                feed_dict={
                    self.decoder_logits: logits,
                    self.target_chars: targets,
                    self.decoder_sequence_length: input_length
                })

        # Construct vector of <GO_ID> tokens as initial input
        batch_size = kwargs['enc_input'].shape[0]
        dec_target = kwargs['dec_target']
        dec_input_length = kwargs['dec_input_length']
        initial_inputs = np.full(shape=(batch_size, ),
                                 fill_value=self.alphabet.GO_ID,
                                 dtype=np.float32)
        max_iterations = self.config.max_dec_seq_length + 1

        # Define initial state
        initial_state = self.decoder.cell.zero_state(batch_size,
                                                     dtype=tf.float32)
        initial_state = pack_state_tuple(initial_state)
        initial_state = session.run(initial_state)

        extra_features = {}

        # Initialize predictor
        if self.sample_type == 'beam':
            predictor = BeamSearchPredictor(batch_size=batch_size,
                                            max_length=max_iterations,
                                            alphabet=self.alphabet,
                                            decode_func=decode_func,
                                            loss_func=loss_func,
                                            beam_size=self.beam_size)
        elif self.sample_type == 'sample':
            predictor = SamplingPredictor(
                batch_size=batch_size,
                max_length=max_iterations,
                alphabet=self.alphabet,
                decode_func=decode_func,
                loss_func=loss_func,
                num_samples=self.beam_size,
            )
        else:
            raise KeyError('Invalid sample_type provided!')
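
        # Hypothetical sketch (not part of the original predictors) showing how
        # the `decode_func` contract above could drive plain greedy decoding;
        # `BeamSearchPredictor` / `SamplingPredictor` are assumed to wrap the
        # same per-step interface. Illustrative only, never called.
        def _greedy_decode_sketch(state):
            inputs, decoded = initial_inputs, []
            for _ in range(max_iterations):
                step = decode_func(inputs, state)
                inputs = np.argmax(step['probs'], axis=1)  # greedy next char
                state = step['state']
                decoded.append(inputs)
            return np.stack(decoded, axis=1)  # [batch_size, max_iterations]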

        # Predict sequence candidates
        final_candidates, final_logits, loss_candidates, prob_x_candidates = predictor.predict_sequences(
            initial_state=initial_state,
            target=dec_target,
            input_length=dec_input_length,
            features=extra_features)

        # Remove predictions after the `<EOS>` id
        for i, j, k in zip(*np.where(
                final_candidates == self.alphabet.EOS_ID)):
            final_candidates[i, j, k + 1:] = 0
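
        # Worked example (values assumed, with `EOS_ID == 2`): the loop above
        # turns a candidate row [5, 7, 2, 9, 4] into [5, 7, 2, 0, 0], i.e.
        # every position after the first `<EOS>` token is zeroed out.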

        return {
            'candidates': final_candidates,
            'loss_candidates': loss_candidates,
            'prob_x_candidates': prob_x_candidates
        }
Example #3
    def setup_network(self):
        # Setup character embedding (defines `self.embedding_matrix`)
        with tf.device('/cpu:0'), tf.variable_scope(name_or_scope='embedding'):
            self.embedding_matrix = tf.get_variable(
                shape=[self.config.alphabet_size, self.config.embedding_size],
                initializer=tf.contrib.layers.xavier_initializer(),
                name='W')

            # Gather slices from `params` according to `indices`
            embedded_decoder_input = tf.nn.embedding_lookup(
                self.embedding_matrix,
                self.decoder_input_chars,
                name='dec_input')

            def embed_func(input_chars):
                return tf.gather(self.embedding_matrix, input_chars)

        # Output projection
        with tf.variable_scope('alphabet_projection') as scope:
            self.projection_W, self.projection_b = intialize_projections(
                input_size=self.config.num_units,
                output_size=self.config.alphabet_size,
                scope=scope)

            # Define alphabet projection function
            def project_func(output):
                return projection(output,
                                  W=self.projection_W,
                                  b=self.projection_b)

        # Define initial state as zero states
        self.enc_final_state = self.decoder.cell.zero_state(
            batch_size=tf.shape(embedded_decoder_input)[0], dtype=tf.float32)

        # Define decoder
        with tf.variable_scope('decoder'):
            dec_outputs, dec_final_state = self.decoder.decode(
                inputs=embedded_decoder_input,
                initial_state=self.enc_final_state,
                seq_length=self.decoder_sequence_length,
                embed_func=embed_func,
                project_func=project_func)

            # Project output to alphabet size and reshape
            dec_outputs = tf.reshape(dec_outputs, [-1, self.config.num_units])
            dec_outputs = projection(dec_outputs,
                                     W=self.projection_W,
                                     b=self.projection_b)
            dec_outputs = tf.reshape(dec_outputs, [
                -1, self.config.max_dec_seq_length + 1,
                self.config.alphabet_size
            ])

            # self.packed_dec_final_state = pack_state_tuple(dec_final_state)

        if self.prediction_mode:
            dec_outputs = self.decoder_logits

        # Define loss
        self.setup_losses(dec_outputs=dec_outputs,
                          target_chars=self.target_chars,
                          decoder_sequence_length=self.decoder_sequence_length)

        if self.prediction_mode:
            # Pack state to tensor
            self.enc_final_state_tensor = pack_state_tuple(
                self.enc_final_state)

            # Look up inputs
            decoder_inputs_embedded = tf.nn.embedding_lookup(
                self.embedding_matrix,
                self.decoder_inputs,
                name='decoder_input')

            # Unpack state
            initial_state = unpack_state_tensor(self.decoder_state)

            with tf.variable_scope('decoder', reuse=True):
                decoder_output, decoder_final_state = self.decoder.predict(
                    inputs=decoder_inputs_embedded,
                    initial_state=initial_state)

                # Project output to alphabet size
                self.decoder_output = projection(decoder_output,
                                                 W=self.projection_W,
                                                 b=self.projection_b,
                                                 name='decoder_output')
                self.decoder_probs = tf.nn.softmax(self.decoder_output,
                                                   name='decoder_probs')
                self.probs_decay_parameter = tf.placeholder(
                    tf.float64, shape=(), name='probs_decay_parameter')
                self.decoder_probs_decayed = tf.pow(
                    tf.cast(self.decoder_probs, tf.float64),
                    self.probs_decay_parameter)
                decoder_probs_sum = tf.expand_dims(
                    tf.reduce_sum(self.decoder_probs_decayed, axis=1), axis=1)
                decoder_probs_sum = tf.tile(decoder_probs_sum,
                                            [1, self.config.alphabet_size])
                self.decoder_probs_decayed = self.decoder_probs_decayed / decoder_probs_sum

                # Pack state to tensor
                self.decoder_final_state = pack_state_tuple(
                    decoder_final_state, name='decoder_final_state')

    def setup_network(self):
        # Setup character embedding
        embedded_encoder_input, embedded_decoder_input, embed_func = self.setup_character_embedding()

        # Output projection
        with tf.variable_scope('alphabet_projection') as scope:
            self.projection_W, self.projection_b = intialize_projections(
                input_size=4 * self.config.num_units,
                output_size=self.config.alphabet_size,
                scope=scope)

            # Define alphabet projection function
            def project_func(output):
                return projection(output,
                                  W=self.projection_W,
                                  b=self.projection_b)

        # Encoder
        with tf.variable_scope('encoder') as scope:
            # Normalize batch
            embedded_encoder_input = tf.layers.batch_normalization(
                inputs=embedded_encoder_input,
                center=True,
                scale=True,
                # `training=True` in both training and inference, since the full
                # question text is available in both cases (an alternative would
                # be `training=not self.prediction_mode`).
                training=True,
                trainable=True,
            )
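
            # Note (not in the original code): `tf.layers.batch_normalization`
            # adds its moving-average updates to `tf.GraphKeys.UPDATE_OPS`,
            # which TensorFlow does not run automatically. If the moving
            # statistics are ever needed (e.g. with
            # `training=not self.prediction_mode`), the training op should be
            # wrapped roughly as sketched below, wherever the optimizer is built:
            #
            #     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            #     with tf.control_dependencies(update_ops):
            #         train_op = optimizer.minimize(loss)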

            enc_outputs, enc_final_state = self.encoder.encode(
                inputs=embedded_encoder_input,
                seq_lengths=self.encoder_sequence_length,
                enc_word_indices=self.enc_word_indices,
                word_seq_lengths=self.word_seq_lengths,
                max_words=self.config.max_words,
                scope=scope)

        # Predict question categories
        with tf.variable_scope('question') as scope:
            # Convert StateTuple to vector
            state_vector = tf.concat(flatten(enc_final_state),
                                     axis=1,
                                     name='combined-state-vec')

            # Add dense layer
            W, b = intialize_projections(
                input_size=4 * self.config.num_units * self.config.num_cells,
                output_size=128)
            layer = tf.nn.relu(tf.matmul(state_vector, W) + b)
            if self.add_dropout:
                layer = tf.nn.dropout(x=layer, keep_prob=self.keep_prob_ph)

            # Compute L2-weight decay
            W_penalty = tf.contrib.layers.apply_regularization(
                regularizer=tf.contrib.layers.l2_regularizer(
                    scale=self.config.W_lambda),
                weights_list=[W])

            class_logits = projection(x=layer,
                                      input_size=128,
                                      output_size=self.config.num_classes)

        # Set decoder initial state and encoder outputs based on the binary
        # mode input value
        # - If `self.is_lm_mode=0` Use the passed initial state from encoder
        # - If `self.is_lm_mode=1` Use the zero vector
        self.enc_outputs, enc_final_state = select_decoder_inputs(
            is_lm_mode=self.is_lm_mode,
            enc_outputs=enc_outputs,
            initial_state=enc_final_state,
        )

        # If an observation has a class -> Pass the true class as 1-hot-encoded
        # vector to the decoder input.
        # If an observation doesn't have a class -> Pass the class logits for
        # the given observation to the decoder input.
        class_is_known = tf.greater_equal(self.class_idx, 0)

        # Create one-hot-encoded vectors
        class_one_hot = tf.one_hot(indices=self.class_idx,
                                   depth=self.config.num_classes,
                                   on_value=1.0,
                                   off_value=0.0,
                                   axis=-1,
                                   dtype=tf.float32,
                                   name='class-one-hot-encoded')

        # Compute class probabilities
        class_probs = tf.nn.softmax(class_logits)

        # Select what to pass on
        self.class_info_vec = tf.where(condition=class_is_known,
                                       x=class_one_hot,
                                       y=class_probs)

        # Concatenate class info vector with decoder input
        _class_info_vec = tf.expand_dims(self.class_info_vec, axis=1)
        _class_info_vec = tf.tile(
            _class_info_vec,
            multiples=[1, self.config.max_dec_seq_length + 1, 1])
        decoder_input = tf.concat([embedded_decoder_input, _class_info_vec],
                                  axis=2)
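
        # Shape walk-through (assuming batch size B, embedding size E, and
        # max decoder length T = max_dec_seq_length): `self.class_info_vec` is
        # [B, num_classes]; `expand_dims` + `tile` turn it into
        # [B, T + 1, num_classes], and the concat with `embedded_decoder_input`
        # ([B, T + 1, E]) gives a decoder input of shape
        # [B, T + 1, E + num_classes].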

        # Pack state to tensor
        self.enc_final_state_tensor = pack_state_tuple(enc_final_state)

        # Initialize decoder attention function using encoder outputs
        self.decoder.initialize_attention_func(
            input_size=decoder_input.get_shape().as_list()[-1],
            attention_states=self.enc_outputs)

        # Define decoder
        with tf.variable_scope('decoder'):
            dec_outputs, dec_final_state = self.decoder.decode(
                inputs=decoder_input,
                initial_state=enc_final_state,
                seq_length=self.decoder_sequence_length,
                embed_func=embed_func,
                project_func=project_func)

            # Project output to alphabet size and reshape
            dec_outputs = tf.reshape(dec_outputs,
                                     [-1, 4 * self.config.num_units])
            dec_outputs = projection(dec_outputs,
                                     W=self.projection_W,
                                     b=self.projection_b)
            dec_outputs = tf.reshape(dec_outputs, [
                -1, self.config.max_dec_seq_length + 1,
                self.config.alphabet_size
            ])

        if self.prediction_mode:
            dec_outputs = self.decoder_logits

        # Define loss
        self.setup_losses(dec_outputs=dec_outputs,
                          target_chars=self.target_chars,
                          decoder_sequence_length=self.decoder_sequence_length,
                          class_probs=class_probs,
                          class_idx=self.class_idx,
                          class_is_known=class_is_known,
                          class_one_hot=class_one_hot,
                          W_penalty=W_penalty)

        if self.prediction_mode:
            # Define initial attention tensor
            self.initial_attention = self.decoder.attention_func(
                enc_final_state)

            # Look up inputs
            decoder_inputs_embedded = tf.nn.embedding_lookup(
                self.embedding_matrix,
                self.decoder_inputs,
                name='decoder_input')
            is_lm_mode_tensor = tf.to_float(
                tf.expand_dims(self.is_lm_mode, axis=1))
            decoder_inputs = tf.concat(
                [decoder_inputs_embedded, is_lm_mode_tensor], axis=1)

            # Concatenate class info vector
            decoder_inputs = tf.concat([decoder_inputs, self.class_info_vec],
                                       axis=1)

            # Unpack state
            initial_state = unpack_state_tensor(self.decoder_state)

            with tf.variable_scope('decoder', reuse=True):
                decoder_output, decoder_final_state, self.decoder_new_attention = self.decoder.predict(
                    inputs=decoder_inputs,
                    initial_state=initial_state,
                    attention_states=self.decoder_attention)

            # Project output to alphabet size
            self.decoder_output = projection(decoder_output,
                                             W=self.projection_W,
                                             b=self.projection_b,
                                             name='decoder_output')

            # Compute decayed logits
            self.decoder_probs_decayed = compute_decayed_probs(
                logits=self.decoder_output,
                decay_parameter_ph=self.probs_decay_parameter)

            # Pack state to tensor
            self.decoder_final_state = pack_state_tuple(
                decoder_final_state, name='decoder_final_state')