示例#1
0
    def __init__(self, embedding, start_tokens, end_token, vocab,
                 reward_metric, ground_truth, ground_truth_length, lambdas):
        """Initialise the base sampling helper and keep the extra inputs.

        Args:
            embedding: forwarded unchanged to ``SampleEmbeddingHelper``.
            start_tokens: forwarded unchanged to ``SampleEmbeddingHelper``.
            end_token: forwarded unchanged to ``SampleEmbeddingHelper``.
            vocab: vocabulary object, kept as ``self._vocab``.
            reward_metric: kept as ``self._metric``.
            ground_truth: reference token ids, kept as ``self._ground_truth``.
            ground_truth_length: kept as ``self._ground_truth_length``.
            lambdas: kept as ``self._lambdas``.
        """
        SampleEmbeddingHelper.__init__(
            self, embedding, start_tokens, end_token)

        # Extra state used by this helper's overridden sampling logic.
        self._metric = reward_metric
        self._lambdas = lambdas
        self._ground_truth = ground_truth
        self._ground_truth_length = ground_truth_length
        self._vocab = vocab
示例#2
0
    def sample(self, time, outputs, state, name=None):
        """Pick the token ids fed to the next decoding step.

        One strategy id is drawn from ``Categorical(probs=self._lambdas)``
        for the whole step, then:

        * ``time == 0`` or strategy 1: ground-truth column ``time`` (cast to
          int32), or ``self._vocab.eos_token_id`` for every row once ``time``
          is past the ground-truth width;
        * strategy 2: ``self._sample_by_reward(time, state)``;
        * any other strategy: ``SampleEmbeddingHelper.sample``.

        ``state`` has the special form ``[decoded_ids, rnn_state]`` and is
        passed through unchanged to the chosen branch.
        """
        strategy_id = tfpd.Categorical(probs=self._lambdas).sample()

        def _feed_truth():
            gold = self._ground_truth
            in_range = tf.less(time, tf.shape(gold)[1])

            def _gold_tokens():
                return tf.cast(gold[:, time], tf.int32)

            def _eos_tokens():
                return tf.ones_like(
                    gold[:, 0], dtype=tf.int32) * self._vocab.eos_token_id

            return tf.cond(in_range, _gold_tokens, _eos_tokens)

        def _feed_self():
            return SampleEmbeddingHelper.sample(
                self, time, outputs, state, name)

        def _feed_reward():
            return self._sample_by_reward(time, state)

        def _feed_model():
            # Strategy 2 -> reward-based sampling, anything else -> free run.
            return tf.cond(tf.equal(strategy_id, 2), _feed_reward, _feed_self)

        # The very first step is always teacher-forced.
        force_truth = tf.logical_or(tf.equal(time, 0),
                                    tf.equal(strategy_id, 1))
        return tf.cond(force_truth, _feed_truth, _feed_model)
示例#3
0
    def next_inputs(self, time, outputs, state, sample_ids, name=None):
        """Advance the base helper and record ``sample_ids`` in the id buffer.

        ``state`` has the special form ``[decoded_ids, rnn_state]``.
        Only ``rnn_state`` is handed to ``SampleEmbeddingHelper.next_inputs``.
        Column ``time`` of ``decoded_ids`` is replaced by ``sample_ids``.

        Returns:
            ``(finished, next_inputs, [decoded_ids, next_rnn_state])`` where
            ``decoded_ids`` keeps the width of the incoming buffer.
        """
        decoded_ids, rnn_state = state[0], state[1]
        finished, next_inputs, next_rnn_state = SampleEmbeddingHelper.next_inputs(
            self, time, outputs, rnn_state, sample_ids, name)

        # Overwrite column `time` with the freshly sampled ids.
        # NOTE(review): if `time` can reach the buffer width, this concat
        # produces width + 1 columns and the reshape below fails; presumably
        # decoding is capped elsewhere -- confirm.
        updated_ids = tf.concat(
            [decoded_ids[:, :time], tf.expand_dims(sample_ids, 1),
             decoded_ids[:, time + 1:]], axis=1)

        # The dynamic slicing above loses the static column count. Reshape
        # restores it, presumably so the buffer's shape stays consistent
        # across decode iterations. The width was hard-coded to 60; it is now
        # taken from the incoming buffer. The dynamic width is used when the
        # static one is unknown.
        static_shape = decoded_ids.get_shape()
        width = static_shape.as_list()[1] if static_shape.ndims is not None else None
        if width is None:
            width = tf.shape(decoded_ids)[1]
        updated_ids = tf.reshape(updated_ids, (tf.shape(sample_ids)[0], width))

        return finished, next_inputs, [updated_ids, next_rnn_state]
示例#4
0
    def _init(self, sequence, targets, authors):
        """Build the author-conditioned decoder and its sequence loss.

        Args:
            sequence: integer token ids fed to the decoder. Lengths are
                computed with ``tf.count_nonzero``, so id 0 is treated as
                padding (assumption -- TODO confirm against the data pipeline).
            targets: integer token ids the decoder logits are scored against.
            authors: ids looked up in the author context embedding table.

        Returns:
            ``(targets, loss, out)``: the ``targets`` argument unchanged, the
            weighted ``sequence_loss``, and the decoder's ``sample_id``.
        """
        batch_size = tf.shape(sequence)[0]

        sequence_lengths = tf.cast(tf.count_nonzero(sequence, axis=1), tf.int32)
        embedding = tf.Variable(
            tf.random_normal((self._vocab_size, self._embed_size)),
            name='char_embedding'
        )
        context = tf.Variable(
            tf.random_normal((self._author_size, self._ctx_size)),
            name='ctx_embedding'
        )

        embedded_sequence = tf.nn.embedding_lookup(embedding, sequence)
        embedded_authors = tf.nn.embedding_lookup(context, authors)

        # Layer i is placed on GPU (i mod num_gpu).
        gpu = lambda x: '/gpu:{}'.format(x % self._num_gpu)

        if self._training:
            dropout = lambda x: DropoutWrapper(
                x, 1.0-self._input_dropout, 1.0-self._output_dropout)
            helper = TrainingHelper(embedded_sequence, sequence_lengths)
        else:
            dropout = lambda x: x
            # NOTE(review): end token id 2 is hard-coded -- confirm it matches
            # the vocabulary. No maximum_iterations is passed to
            # dynamic_decode, so sampling only stops once every row emits it.
            helper = SampleEmbeddingHelper(embedding, sequence[:,0], 2)

        # Every layer is wrapped with the author context.
        base = lambda x: ContextWrapper(self._cell(x), embedded_authors)
        wrap = lambda i, cell: DeviceWrapper(dropout(cell), gpu(i))
        cells = [wrap(i, base(self._cell_size)) for i in range(self._cell_num)]
        cell = MultiRNNCell(cells)

        init_state = cell.zero_state(batch_size, tf.float32)
        dense = tf.layers.Dense(
            self._vocab_size, self._activation, name='fully_connected'
        )
        decoder = BasicDecoder(cell, helper, init_state, dense)
        output, _, _ = dynamic_decode(decoder, swap_memory=True)
        logits = output.rnn_output

        weights = tf.sequence_mask(sequence_lengths, dtype=tf.float32)
        # dynamic_decode only emits as many steps as the decoder runs. With
        # TrainingHelper that is max(sequence_lengths), which is also the
        # width of `weights`. `targets` can be wider when a batch carries
        # extra padding columns, so trim it to the decoded length. This is a
        # no-op whenever the widths already agree.
        # NOTE(review): in inference mode the sampled length is not tied to
        # `sequence_lengths`, so the loss shapes may still disagree there.
        aligned_targets = targets[:, :tf.shape(logits)[1]]
        loss = tf.contrib.seq2seq.sequence_loss(
            logits,
            aligned_targets,
            weights
        )

        out = output.sample_id

        return targets, loss, out
示例#5
0
    def _init(self):
        """Create input placeholders, the decoder graph and its loss.

        Returns:
            ``(sequence, authors, targets, loss, out)``. The first three are
            the int32 placeholders named ``'sequence'``, ``'authors'`` and
            ``'targets'``. ``loss`` is the mean softmax cross-entropy of the
            logits against one-hot targets. ``out`` is ``softmax(logits)``.

        Raises:
            ValueError: if ``self._cell`` is neither ``'lstm'`` nor ``'gru'``.
        """
        sequence = tf.placeholder(tf.int32, [None, None], name='sequence')
        targets = tf.placeholder(tf.int32, [None, None], name='targets')
        authors = tf.placeholder(tf.int32, [None, None], name='authors')

        batch_size = tf.shape(sequence)[0]

        # Lengths count non-zero ids, i.e. id 0 is assumed to be padding
        # (TODO confirm against the data pipeline).
        sequence_lengths = tf.cast(tf.count_nonzero(sequence, axis=1),
                                   tf.int32)
        embedding = tf.Variable(
            tf.random_normal((self._vocab_size, self._embed_size)))
        context = tf.Variable(
            tf.random_normal((self._author_size, self._ctx_size)))

        embedded_sequence = tf.nn.embedding_lookup(embedding, sequence)
        embedded_authors = tf.nn.embedding_lookup(context, authors)
        one_hot_targets = tf.one_hot(targets, self._vocab_size)

        gpu = lambda x: str(x % self._num_gpu)

        if self._attn:
            # Attend over the embedded input sequence. The attention wrapper
            # is pinned to gpu(1).
            mech = BahdanauAttention(self._attn_depth, embedded_sequence,
                                     sequence_lengths)
            attn_cell = lambda x: DeviceWrapper(
                AttentionWrapper(x, mech, self._attn_size), "/gpu:" + gpu(1))
        else:
            attn_cell = lambda x: x

        if self._training:
            # Dropout on cell outputs only.
            dropout = lambda x: DropoutWrapper(x, 1.0, 1.0 - self._dropout)
        else:
            dropout = lambda x: x

        if self._cell == 'lstm':
            base_cell = lambda x: dropout(BasicLSTMCell(x))
        elif self._cell == 'gru':
            base_cell = lambda x: dropout(GRUCell(x))
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `base_cell`.
            raise ValueError(
                "Unsupported cell type {!r}; expected 'lstm' or 'gru'".format(
                    self._cell))

        # Only the bottom layer receives the author context (and attention).
        context_cell = ContextWrapper(
            base_cell(self._cell_size),
            embedded_authors,
        )
        bottom_cell = DeviceWrapper(attn_cell(context_cell), "/gpu:0")
        top_cells = [
            DeviceWrapper(base_cell(self._cell_size), "/gpu:" + gpu(i))
            for i in range(1, self._cell_num)
        ]
        cell = MultiRNNCell([bottom_cell] + top_cells)

        init_state = cell.zero_state(batch_size, tf.float32)

        if self._training:
            helper = TrainingHelper(embedded_sequence, sequence_lengths)
        else:
            # NOTE(review): end token id 1 is hard-coded -- confirm it matches
            # the vocabulary.
            helper = SampleEmbeddingHelper(embedding, sequence[:, 0], 1)

        dense = Dense(self._vocab_size, self._activation)
        decoder = BasicDecoder(cell, helper, init_state, dense)
        output, _, _ = dynamic_decode(decoder, swap_memory=True)
        logits = output.rnn_output

        # NOTE(review): no sequence mask is applied, so padded positions
        # contribute to the mean loss.
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                       labels=one_hot_targets)
        loss = tf.reduce_mean(loss)

        out = tf.nn.softmax(logits)

        return sequence, authors, targets, loss, out