# Assumed module-level imports, not shown in this snippet: TensorFlow 1.x as
# `tf`, Sonnet 1 as `snt`, TensorFlow Probability as `tfp`, a `logging` module,
# and the project-local `utils` module.
def _build(self, sequence, sequence_length, is_training=True):
        """Connect to the graph.

    Args:
      sequence: A [batch_size, max_sequence_length] tensor of int. For example
        the indices of words as sampled by the generator.
      sequence_length: A [batch_size] tensor of int. Length of the sequence.
      is_training: Boolean, False to disable dropout.

    Returns:
      A [batch_size, max_sequence_length, feature_size] tensor of floats. For
      each sequence in the batch, the features should (hopefully) allow to
      distinguish if the value at each timestep is real or generated.
    """
        batch_size, max_sequence_length = sequence.shape.as_list()
        keep_prob = (1.0 - self._dropout) if is_training else 1.0

        if self._embedding_source:
            all_embeddings = utils.make_partially_trainable_embeddings(
                self._vocab_file, self._embedding_source, self._vocab_size,
                self._trainable_embedding_size)
        else:
            all_embeddings = tf.get_variable(
                'trainable_embedding',
                shape=[self._vocab_size, self._trainable_embedding_size],
                trainable=True)
        _, self._embedding_size = all_embeddings.shape.as_list()
        input_embeddings = tf.nn.dropout(all_embeddings, keep_prob=keep_prob)
        embeddings = tf.nn.embedding_lookup(input_embeddings, sequence)
        embeddings.shape.assert_is_compatible_with(
            [batch_size, max_sequence_length, self._embedding_size])
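        # Append a small fixed-size position signal to each timestep's
        # embedding so the features carry information about where the timestep
        # sits in the sequence.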
        position_dim = 8
        embeddings_pos = utils.append_position_signal(embeddings, position_dim)
        embeddings_pos = tf.reshape(embeddings_pos, [
            batch_size * max_sequence_length,
            self._embedding_size + position_dim
        ])
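        # Flatten batch and time so a single snt.Linear can project every
        # timestep down to the first LSTM feature size.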
        lstm_inputs = snt.Linear(self._feature_sizes[0])(embeddings_pos)
        lstm_inputs = tf.reshape(
            lstm_inputs,
            [batch_size, max_sequence_length, self._feature_sizes[0]])
        lstm_inputs.shape.assert_is_compatible_with(
            [batch_size, max_sequence_length, self._feature_sizes[0]])

        encoder_cells = []
        for feature_size in self._feature_sizes:
            encoder_cells += [
                snt.LSTM(feature_size, use_layer_norm=self._use_layer_norm)
            ]
        encoder_cell = snt.DeepRNN(encoder_cells)
        initial_state = encoder_cell.initial_state(batch_size)

        hidden_states, _ = tf.nn.dynamic_rnn(cell=encoder_cell,
                                             inputs=lstm_inputs,
                                             sequence_length=sequence_length,
                                             initial_state=initial_state,
                                             swap_memory=True)
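        # With snt.DeepRNN's default skip connections, the per-step output is
        # the concatenation of every layer's output, hence the
        # sum(self._feature_sizes) feature dimension asserted below.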

        hidden_states.shape.assert_is_compatible_with(
            [batch_size, max_sequence_length,
             sum(self._feature_sizes)])
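        # A single Linear(1), applied per timestep via snt.BatchApply, scores
        # each position as real or generated.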
        logits = snt.BatchApply(snt.Linear(1))(hidden_states)
        logits.shape.assert_is_compatible_with(
            [batch_size, max_sequence_length, 1])
        logits_flat = tf.reshape(logits, [batch_size, max_sequence_length])

        # Mask past the first PAD symbol.
        #
        # Note that we still rely on tf.nn.dynamic_rnn above taking
        # sequence_length into account properly, because otherwise the logits
        # at a given timestep would depend on the inputs at all other
        # timesteps, including the ones that should be masked.
        mask = utils.get_mask_past_symbol(sequence, self._pad_token)
        masked_logits_flat = logits_flat * tf.cast(mask, tf.float32)
        return masked_logits_flat
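
# A minimal sketch, not the actual utils implementation, of the masking
# behaviour assumed from utils.get_mask_past_symbol above: keep positions up to
# and including the first PAD token, zero out everything after it. The helper
# name `mask_past_first_symbol` is hypothetical and used only for illustration.
def mask_past_first_symbol(sequence, symbol):
    """Returns an int32 [batch_size, max_sequence_length] mask."""
    is_symbol = tf.cast(tf.equal(sequence, symbol), tf.int32)
    # Number of `symbol` tokens strictly before each position.
    seen_before = tf.cumsum(is_symbol, axis=1, exclusive=True)
    # A position is kept while no `symbol` has been seen before it.
    return tf.cast(tf.equal(seen_before, 0), tf.int32)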
Example #2
def _build(self, is_training=True, temperature=1.0):
        """Connects the generator to the graph.

        Autoregressively unrolls an LSTM stack for `max_sequence_length` steps,
        sampling a token at each step and feeding it back as the next input.

        Args:
          is_training: Boolean, False to disable dropout.
          temperature: Float, the sampling temperature applied to the logits.

        Returns:
          A dict with the sampled 'sequence' and its per-token 'logprobs', both
          masked past the first PAD token, plus the corresponding
          'sequence_length'.
        """
        input_keep_prob = (1. - self._input_dropout) if is_training else 1.0
        output_keep_prob = (1. - self._output_dropout) if is_training else 1.0

        batch_size = self._batch_size
        max_sequence_length = self._max_sequence_length
        if self._embedding_source:
            all_embeddings = utils.make_partially_trainable_embeddings(
                self._vocab_file, self._embedding_source, self._vocab_size,
                self._trainable_embedding_size)
        else:
            all_embeddings = tf.get_variable(
                'trainable_embeddings',
                shape=[self._vocab_size, self._trainable_embedding_size],
                trainable=True)
        _, self._embedding_size = all_embeddings.shape.as_list()
        input_embeddings = tf.nn.dropout(all_embeddings,
                                         keep_prob=input_keep_prob)
        output_embeddings = tf.nn.dropout(all_embeddings,
                                          keep_prob=output_keep_prob)
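        # The input and output embeddings are two independently dropout-masked
        # views of the same underlying embedding matrix.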

        out_bias = tf.get_variable('out_bias',
                                   shape=[1, self._vocab_size],
                                   dtype=tf.float32)
        in_proj = tf.get_variable(
            'in_proj', shape=[self._embedding_size, self._feature_sizes[0]])
        # If there is more than one layer, the output has dimension
        # sum(self._feature_sizes), which differs from the input dimension
        # self._feature_sizes[0], so we need different projection matrices for
        # the input and the output.
        if len(self._feature_sizes) > 1:
            out_proj = tf.get_variable(
                'out_proj',
                shape=[self._embedding_size,
                       sum(self._feature_sizes)])
        else:
            out_proj = in_proj

        encoder_cells = []
        for feature_size in self._feature_sizes:
            encoder_cells += [
                snt.LSTM(feature_size, use_layer_norm=self._use_layer_norm)
            ]
        encoder_cell = snt.DeepRNN(encoder_cells)
        state = encoder_cell.initial_state(batch_size)

        # Manual unrolling.
        samples_list, logits_list, logprobs_list, embeddings_list = [], [], [], []
        sample = tf.tile(
            tf.constant(self._pad_token, dtype=tf.int32)[None], [batch_size])
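        # The unroll is seeded with the PAD token for every batch element,
        # which effectively acts as a start-of-sequence input.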
        logging.info('Unrolling over %d steps.', max_sequence_length)
        for _ in range(max_sequence_length):
            # Input is sampled word at t-1.
            embedding = tf.nn.embedding_lookup(input_embeddings, sample)
            embedding.shape.assert_is_compatible_with(
                [batch_size, self._embedding_size])
            embedding_proj = tf.matmul(embedding, in_proj)
            embedding_proj.shape.assert_is_compatible_with(
                [batch_size, self._feature_sizes[0]])

            outputs, state = encoder_cell(embedding_proj, state)
            outputs_proj = tf.matmul(outputs, out_proj, transpose_b=True)
            logits = tf.matmul(
                outputs_proj, output_embeddings, transpose_b=True) + out_bias
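            # Output weights are tied to the embeddings: the LSTM output is
            # projected back to the embedding size and dotted with the
            # (dropout-masked) output embeddings, plus a learned bias. The
            # logits are then divided by `temperature` before sampling.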
            categorical = tfp.distributions.Categorical(logits=logits /
                                                        temperature)
            sample = categorical.sample()
            logprobs = categorical.log_prob(sample)

            samples_list.append(sample)
            logits_list.append(logits)
            logprobs_list.append(logprobs)
            embeddings_list.append(embedding)

        # Create an op to retrieve the embeddings for the full sequence; useful
        # for testing.
        embeddings = tf.stack(  # pylint: disable=unused-variable
            embeddings_list,
            axis=1,
            name='embeddings')
        sequence = tf.stack(samples_list, axis=1)
        logprobs = tf.stack(logprobs_list, axis=1)

        # The sequence stops after the first occurrence of a PAD token.
        sequence_length = utils.get_first_occurrence_indices(
            sequence, self._pad_token)
        mask = utils.get_mask_past_symbol(sequence, self._pad_token)
        masked_sequence = sequence * tf.cast(mask, tf.int32)
        masked_logprobs = logprobs * tf.cast(mask, tf.float32)
        return {
            'sequence': masked_sequence,
            'sequence_length': sequence_length,
            'logprobs': masked_logprobs
        }
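
# A minimal sketch, not the actual utils implementation, of the behaviour
# assumed from utils.get_first_occurrence_indices above: for each row, the
# index of the first PAD token, or max_sequence_length when no PAD is present.
# Whether the real utility also counts the PAD itself is not visible from this
# snippet. The helper name `first_occurrence_indices` is hypothetical.
def first_occurrence_indices(sequence, symbol):
    """Returns an int32 [batch_size] tensor of first-PAD positions."""
    batch_size, max_len = sequence.shape.as_list()
    is_symbol = tf.equal(sequence, symbol)
    has_symbol = tf.reduce_any(is_symbol, axis=1)
    # tf.argmax returns the first index at which the maximum (True) occurs.
    first_idx = tf.argmax(
        tf.cast(is_symbol, tf.int32), axis=1, output_type=tf.int32)
    return tf.where(has_symbol, first_idx, tf.fill([batch_size], max_len))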