Example #1
    def create_pretrain_network(self):
        with tf.variable_scope('attention'):
            (self.attention_keys, self.attention_values, _,
             self.attention_construct_fn) = attention_utils.prepare_attention(
                 self.encoded_seed,
                 'luong',
                 num_units=self.hidden_dim,
                 reuse=True)

        # supervised pretraining for generator
        g_predictions = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                     size=self.sequence_length,
                                                     dynamic_size=False,
                                                     infer_shape=True)

        # processed for batch
        with tf.device("/cpu:0"):
            self.processed_x = tf.transpose(
                tf.nn.embedding_lookup(self.g_embeddings, self.x),
                perm=[1, 0, 2])  # seq_length x batch_size x emb_dim

        ta_emb_x = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=self.sequence_length)
        ta_emb_x = ta_emb_x.unstack(self.processed_x)

        output_mask = None
        if self.is_training:
            output_mask = make_mask(self.batch_size, self.gen_vd_keep_prob,
                                    self.hidden_dim)

        def _pretrain_recurrence(i, x_t, prev_h_state, g_predictions):
            out, state = self.g_recurrent_unit(x_t, prev_h_state)
            out = self.attention_construct_fn(out, self.attention_keys,
                                              self.attention_values)

            if output_mask is not None:
                out *= output_mask

            o_t = self.g_output_unit(out)  # batch x vocab_size, logits (not probabilities)

            g_predictions = g_predictions.write(
                i, tf.nn.softmax(o_t))  # batch x vocab_size

            x_tp1 = ta_emb_x.read(i)
            return i + 1, x_tp1, state, g_predictions

        _, _, _, self.g_predictions = control_flow_ops.while_loop(
            cond=lambda i, _1, _2, _3: i < self.sequence_length,
            body=_pretrain_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.g_embeddings,
                                              self.start_token), self.h0,
                       g_predictions))

        self.g_predictions = self.g_predictions.stack()
        self.g_predictions = tf.transpose(
            self.g_predictions,
            perm=[1, 0, 2])  # batch_size x seq_length x vocab_size

        print(self.g_predictions.shape)
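
Both Example #1 and Example #3 call a module-level make_mask helper that is not reproduced on this page. A minimal sketch, assuming it uses the same inverted-dropout trick as the inline _make_mask in Example #2 but returns a 2-D [batch_size, units] mask to match the per-step RNN output:

def make_mask(batch_size, keep_prob, units):
    # Sketch only: mirrors _make_mask from Example #2, minus the middle
    # time dimension, since it is applied to [batch_size, units] outputs.
    random_tensor = keep_prob
    random_tensor += tf.random_uniform(tf.stack([batch_size, units]))
    # floor() yields a 0/1 Bernoulli mask; dividing by keep_prob rescales
    # the surviving activations (inverted dropout).
    return tf.floor(random_tensor) / keep_prob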
Example #2
    def create(self, reuse=None):
        g_embeddings = self.create_embedding()
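        # Append one extra embedding row to stand in for masked / missing tokens.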
        init_missing_embedding = tf.random_uniform([1, self.emb_dim], -1.0,
                                                   1.0)
        missing_embedding = tf.get_variable("missing_embedding",
                                            initializer=init_missing_embedding)
        self.g_embeddings = tf.concat([g_embeddings, missing_embedding],
                                      axis=0)

        attn_cell = self.create_cell(reuse=reuse)
        with tf.variable_scope('seed_encoder') as scope:
            embedded_seed = tf.nn.embedding_lookup(self.g_embeddings,
                                                   self.masked_inputs)

            cell = tf.contrib.rnn.MultiRNNCell(
                [attn_cell() for _ in range(self.n_rnn_layers)],
                state_is_tuple=True)
            state = cell.zero_state(self.batch_size, tf.float32)

            outputs, state = tf.nn.dynamic_rnn(cell,
                                               embedded_seed,
                                               initial_state=state,
                                               scope=scope)

            def _make_mask(batch_size, keep_prob, units):
                random_tensor = keep_prob
                random_tensor += tf.random_uniform(
                    tf.stack([batch_size, 1, units]))
                return tf.floor(random_tensor) / keep_prob

            if self.is_training:
                output_mask = _make_mask(self.batch_size,
                                         self.gen_vd_keep_prob,
                                         self.hidden_dim)
                outputs *= output_mask

            self.encoded_seed = outputs
            # Initial states
            # self.h0 = self.g_recurrent_unit.zero_state(self.batch_size, tf.float32)
            self.h0 = state

        with tf.variable_scope('attention'):
            print("encoded seed:", self.encoded_seed.shape)
            (self.attention_keys, self.attention_values, _,
             self.attention_construct_fn) = attention_utils.prepare_attention(
                 self.encoded_seed, 'luong', num_units=self.hidden_dim)

        with tf.variable_scope('generator'):
            self.g_recurrent_unit = tf.contrib.rnn.MultiRNNCell(
                [attn_cell() for _ in range(self.n_rnn_layers)],
                state_is_tuple=True)
        self.g_output_unit = self.create_output_unit()  # maps h_t to o_t (output token logits)
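
self.create_cell is not shown in this example. A minimal sketch, assuming it builds the same kind of cell factory as Example #3 below (a BasicLSTMCell, wrapped in VariationalDropoutWrapper while training with gen_vd_keep_prob < 1); the original implementation may differ:

def create_cell(self, reuse=None):
    # Hypothetical sketch modeled on the lstm_cell / attn_cell pair in Example #3.
    def lstm_cell():
        return tf.contrib.rnn.BasicLSTMCell(self.hidden_dim,
                                            forget_bias=0.0,
                                            state_is_tuple=True,
                                            reuse=reuse)

    if self.is_training and self.gen_vd_keep_prob < 1:
        # Variational dropout: one mask shared across all time steps.
        def attn_cell():
            return VariationalDropoutWrapper(lstm_cell(),
                                             self.batch_size,
                                             self.gen_vd_keep_prob,
                                             self.gen_vd_keep_prob)
        return attn_cell
    return lstm_cell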
Example #3
def dis_decoder(generator,
                embedding,
                sequence,
                encoding_state,
                reuse=None,
                dis_num_layers=2):
    sequence = tf.cast(sequence, tf.int32)

    with tf.variable_scope('decoder', reuse=reuse):

        def lstm_cell():
            return tf.contrib.rnn.BasicLSTMCell(generator.hidden_dim,
                                                forget_bias=0.0,
                                                state_is_tuple=True,
                                                reuse=reuse)

        attn_cell = lstm_cell
        if generator.is_training and generator.dis_vd_keep_prob < 1:

            def attn_cell():
                return VariationalDropoutWrapper(lstm_cell(),
                                                 generator.batch_size,
                                                 generator.dis_vd_keep_prob,
                                                 generator.dis_vd_keep_prob)

        cell_dis = tf.contrib.rnn.MultiRNNCell(
            [attn_cell() for _ in range(dis_num_layers)], state_is_tuple=True)

        state = encoding_state[1]

        (attention_keys, attention_values, _,
         attention_construct_fn) = attention_utils.prepare_attention(
             encoding_state[0],
             'luong',
             num_units=generator.hidden_dim,
             reuse=reuse)

        if generator.is_training:
            output_mask = make_mask(generator.batch_size,
                                    generator.dis_vd_keep_prob,
                                    generator.hidden_dim)

        with tf.variable_scope('rnn') as vs:
            predictions = []

            rnn_inputs = tf.nn.embedding_lookup(embedding, sequence)

            for t in range(generator.sequence_length):
                if t > 0:
                    tf.get_variable_scope().reuse_variables()

                rnn_in = rnn_inputs[:, t]
                rnn_out, state = cell_dis(rnn_in, state)
                rnn_out = attention_construct_fn(rnn_out, attention_keys,
                                                 attention_values)
                if generator.is_training:
                    rnn_out *= output_mask

                pred = tf.contrib.layers.linear(rnn_out, 1, scope=vs)
                predictions.append(pred)
        predictions = tf.stack(predictions, axis=1)
    return tf.squeeze(predictions, axis=2)
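
dis_decoder returns one logit per position, shaped [batch_size, sequence_length] after the final squeeze. A minimal usage sketch showing how such per-position logits are typically fed into a sigmoid cross-entropy discriminator loss; the targets tensor here (1.0 = real token, 0.0 = generated) is hypothetical and not part of the example:

logits = dis_decoder(generator, embedding, sequence, encoding_state)  # [batch, seq_len]
targets = tf.ones_like(logits)  # illustration only: mark every position as "real"
per_step_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits)
dis_loss = tf.reduce_mean(per_step_loss)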