Code Example #1
def concat_and_gather_tuple_states(pc_states, c_state):
    # Append the newest (c, h) pair to each beam's stored state history
    # along the time axis, then reorder the histories with `parent_refs`.
    # `parent_refs` and `nest_map` come from the enclosing scope
    # (see Code Example #4).
    rc_states = (
        tf.concat(axis=1, values=[pc_states[0], tf.expand_dims(c_state[0], 1)]),
        tf.concat(axis=1, values=[pc_states[1], tf.expand_dims(c_state[1], 1)])
    )
    c_states = (
        nest_map(lambda element: tf.gather(element, parent_refs), rc_states[0]),
        nest_map(lambda element: tf.gather(element, parent_refs), rc_states[1])
    )
    return c_states
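This helper appends the newest LSTM (c, h) pair to each beam's stored state history and then gathers whole histories by `parent_refs`, so every surviving beam inherits its parent's past. A minimal NumPy sketch of the same append-then-gather bookkeeping (all sizes and values below are illustrative assumptions, not taken from the project):

    import numpy as np

    # Illustrative sizes: beam=2, 3 past steps, hidden dim 4.
    pc_c = np.zeros((2, 3, 4)); pc_h = np.ones((2, 3, 4))    # state history per beam
    c_c = np.full((2, 4), 2.0); c_h = np.full((2, 4), 3.0)   # newest (c, h) pair
    parent_refs = np.array([1, 1])  # both surviving beams descend from beam 1

    # Append the new step along the time axis: [beam, t, dim] -> [beam, t+1, dim]
    rc_c = np.concatenate([pc_c, c_c[:, None, :]], axis=1)
    rc_h = np.concatenate([pc_h, c_h[:, None, :]], axis=1)

    # Gather whole histories so each surviving beam inherits its parent's past
    new_c, new_h = rc_c[parent_refs], rc_h[parent_refs]
    assert new_c.shape == (2, 4, 4) and (new_c[0] == new_c[1]).all()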
Code Example #2
    def _tile_along_beam(cls, beam_size, state):
        if nest.is_sequence(state):
            return nest_map(
                lambda val: cls._tile_along_beam(beam_size, val),
                state
            )

        if not isinstance(state, tf.Tensor):
            raise ValueError("State should be a sequence or tensor")

        tensor = state

        tensor_shape = tensor.get_shape().with_rank_at_least(1)

        try:
            new_first_dim = tensor_shape[0] * beam_size
        except (TypeError, ValueError):
            # the static first dimension is unknown
            new_first_dim = None

        # Insert a beam axis, tile along it, then fold it back into the
        # batch axis: [batch, ...] -> [batch * beam_size, ...], with each
        # row repeated beam_size times.
        dynamic_tensor_shape = tf.unstack(tf.shape(tensor))
        res = tf.expand_dims(tensor, 1)
        res = tf.tile(res, [1, beam_size] + [1] * (tensor_shape.ndims - 1))
        res = tf.reshape(res, [-1] + list(dynamic_tensor_shape[1:]))
        res.set_shape([new_first_dim] + list(tensor_shape[1:]))
        return res
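`_tile_along_beam` inserts a beam axis, tiles along it, and folds it back into the batch axis, so each input row is repeated `beam_size` times with the copies of the same example kept adjacent. A NumPy sketch of the same semantics (the helper name `tile_along_beam_np` is ours, for illustration only):

    import numpy as np

    def tile_along_beam_np(beam_size, x):
        # Repeat each row beam_size times, keeping copies of the same
        # example adjacent: [a, b] -> [a, a, b, b] for beam_size=2.
        return np.repeat(x, beam_size, axis=0)

    x = np.array([[1, 2], [3, 4]])
    print(tile_along_beam_np(2, x))
    # [[1 2]
    #  [1 2]
    #  [3 4]
    #  [3 4]]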
Code Example #3
File: beam_search.py  Project: sxdkxgwan/awesome_nmt
    def _create_state(self, batch_size, dtype, cell_state=None):
        cand_symbols = tf.fill([batch_size, 1],
                               tf.constant(self.start_token, dtype=tf.int32))
        cand_logprobs = tf.ones(
            (batch_size, ), dtype=tf.float32) * -float('inf')

        if cell_state is None:
            cell_state = self.cell.zero_state(batch_size * self.beam_size,
                                              dtype=dtype)
        else:
            cell_state = BeamDecoder._tile_along_beam(self.beam_size,
                                                      cell_state)
        full_size = batch_size * self.beam_size
        first_in_beam_mask = tf.equal(tf.range(full_size) % self.beam_size, 0)

        beam_symbols = tf.fill([full_size, 1],
                               tf.constant(self.start_token, dtype=tf.int32))
        beam_logprobs = tf.where(
            first_in_beam_mask,
            tf.fill([full_size], 0.0),
            tf.fill([full_size], -1e18),  # top_k does not play well with -inf
            # TODO: dtype-dependent value here
        )

        return (cand_symbols, cand_logprobs, beam_symbols, beam_logprobs,
                nest_map(lambda element: tf.expand_dims(element, 1),
                         cell_state))
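The `first_in_beam_mask` trick: after tiling, each example starts with `beam_size` identical copies of the start token, so only the first copy per beam gets initial log-probability 0 and the rest get -1e18; the first `top_k` therefore cannot fill the beam with duplicates. A NumPy sketch of this initialization (sizes are illustrative):

    import numpy as np

    batch_size, beam_size = 2, 3
    full_size = batch_size * beam_size

    # Only the first entry of each beam starts "alive".
    first_in_beam_mask = (np.arange(full_size) % beam_size) == 0
    beam_logprobs = np.where(first_in_beam_mask, 0.0, -1e18)
    print(beam_logprobs.reshape(batch_size, beam_size))
    # [[ 0.e+00 -1.e+18 -1.e+18]
    #  [ 0.e+00 -1.e+18 -1.e+18]]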
Code Example #4
    def __call__(self, cell_inputs, state, scope=None):
        (
            past_beam_symbols,      # [batch_size*self.beam_size, :], right-aligned!!!
            past_beam_logprobs,     # [batch_size*self.beam_size]
            past_cell_states        # LSTM: ([batch_size*self.beam_size, :, dim],
                                    #        [batch_size*self.beam_size, :, dim])
                                    # GRU: [batch_size*self.beam_size, :, dim]
        ) = state

        past_cell_state = self.get_last_cell_state(past_cell_states)
        if self.use_copy and self.copy_fun == 'copynet':
            cell_output, cell_state, alignments, attns = \
                self.cell(cell_inputs, past_cell_state, scope)
        elif self.use_attention:
            cell_output, cell_state, alignments, attns = \
                self.cell(cell_inputs, past_cell_state, scope)
        else:
            cell_output, cell_state = \
                self.cell(cell_inputs, past_cell_state, scope)

        # [batch_size*beam_size, num_classes]
        if self.use_copy and self.copy_fun == 'copynet':
            logprobs = tf.log(cell_output)
        else:
            W, b = self.output_project
            if self.locally_normalized:
                logprobs = tf.nn.log_softmax(tf.matmul(cell_output, W) + b)
            else:
                logprobs = tf.matmul(cell_output, W) + b
        num_classes = logprobs.get_shape()[1].value

        # stop_mask: indicates partial sequences ending with a stop token
        # [batch_size * beam_size]
        # x     0
        # _STOP 1
        # x     0
        # x     0
        input_symbols = past_beam_symbols[:, -1]
        stop_mask = tf.expand_dims(tf.cast(
            tf.equal(input_symbols, self.stop_token), tf.float32), 1)

        # done_mask: indicates stop token in the output vocabulary
        # [1, num_classes]
        # [- - _STOP - - -]
        # [0 0 1 0 0 0]
        done_mask = tf.cast(tf.reshape(tf.equal(tf.range(num_classes),
                                                self.stop_token),
                                       [1, num_classes]),
                            tf.float32)
        # set the next token distribution of partial sequences ending with
        # a stop token to:
        # [- - _STOP - - -]
        # [-inf -inf 0 -inf -inf -inf]
        logprobs = tf.add(logprobs, tf.multiply(
            stop_mask, -1e18 * (tf.ones_like(done_mask) - done_mask)))
        logprobs = tf.multiply(logprobs, (1 - tf.multiply(stop_mask, done_mask)))

        # Length normalization: beam scores are stored length-normalized,
        # so undo the division by len^alpha, add the new step's logprobs,
        # then renormalize by the updated sequence length.
        past_logprobs_unnormalized = \
            tf.multiply(past_beam_logprobs, tf.pow(self.seq_len, self.alpha))
        logprobs_unnormalized = \
            tf.expand_dims(past_logprobs_unnormalized, 1) + logprobs
        seq_len = tf.expand_dims(self.seq_len, 1) + (1 - stop_mask)
        logprobs_batched = tf.div(logprobs_unnormalized, tf.pow(seq_len, self.alpha))

        beam_logprobs, indices = tf.nn.top_k(
            tf.reshape(logprobs_batched, [-1, self.beam_size * num_classes]),
            self.beam_size
        )
        beam_logprobs = tf.reshape(beam_logprobs, [-1])

        # For continuing to the next symbols
        parent_refs_offsets = \
            (tf.range(self.full_size) // self.beam_size) * self.beam_size
        symbols = indices % num_classes # [batch_size, self.beam_size]
        parent_refs = tf.reshape(indices // num_classes, [-1]) # [batch_size*self.beam_size]
        parent_refs = parent_refs + parent_refs_offsets

        beam_symbols = tf.concat(axis=1, values=[tf.gather(past_beam_symbols, parent_refs),
                                                 tf.reshape(symbols, [-1, 1])])
        self.seq_len = tf.squeeze(tf.gather(seq_len, parent_refs), axis=[1])

        if self.use_attention:
            ranked_alignments = nest_map(
                lambda element: tf.gather(element, parent_refs), alignments)
            ranked_attns = nest_map(
                lambda element: tf.gather(element, parent_refs), attns)

        # update cell_states
        def concat_and_gather_tuple_states(pc_states, c_state):
            rc_states = (
                tf.concat(axis=1, values=[pc_states[0], tf.expand_dims(c_state[0], 1)]),
                tf.concat(axis=1, values=[pc_states[1], tf.expand_dims(c_state[1], 1)])
            )
            c_states = (
                nest_map(lambda element: tf.gather(element, parent_refs), rc_states[0]),
                nest_map(lambda element: tf.gather(element, parent_refs), rc_states[1])
            )
            return c_states

        if nest.is_sequence(cell_state):
            if self.num_layers > 1:
                ranked_cell_states = [concat_and_gather_tuple_states(pc_states, c_state)
                    for pc_states, c_state in zip(past_cell_states, cell_state)]
            else:
                ranked_cell_states = concat_and_gather_tuple_states(
                    past_cell_states, cell_state)
        else:
            ranked_cell_states = tf.gather(
                tf.concat(axis=1, values=[past_cell_states, tf.expand_dims(cell_state, 1)]),
                parent_refs)

        compound_cell_state = (
            beam_symbols,
            beam_logprobs,
            ranked_cell_states
        )
        ranked_cell_output = tf.gather(cell_output, parent_refs)

        if self.use_copy and self.copy_fun == 'copynet':
            return ranked_cell_output, compound_cell_state, ranked_alignments, \
                   ranked_attns
        elif self.use_attention:
            return ranked_cell_output, compound_cell_state, ranked_alignments, \
                   ranked_attns
        else:
            return ranked_cell_output, compound_cell_state
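The core of the beam step is the index arithmetic around `tf.nn.top_k`: scores are flattened to `[batch, beam_size * num_classes]`, so each selected index decodes into a parent beam (`// num_classes`) and an emitted symbol (`% num_classes`), and per-example offsets map the parent indices back into the flat `[batch * beam_size]` axis. A NumPy sketch with made-up indices (values are illustrative, not real scores):

    import numpy as np

    batch_size, beam_size, num_classes = 2, 2, 5
    full_size = batch_size * beam_size

    # Pretend top_k over [batch, beam_size * num_classes] returned these.
    indices = np.array([[7, 3],   # example 0
                        [2, 9]])  # example 1

    symbols = indices % num_classes   # which vocabulary symbol was chosen
    parents = indices // num_classes  # which beam within the example
    # Offset each example's parent indices into the flat [batch*beam] axis.
    offsets = (np.arange(full_size) // beam_size) * beam_size
    parent_refs = parents.reshape(-1) + offsets
    print(symbols.reshape(-1))  # [2 3 2 4]
    print(parent_refs)          # [1 0 2 3]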
Code Example #5
File: beam_search.py  Project: sxdkxgwan/awesome_nmt
    def __call__(self, inputs, state, scope=None):
        (
            past_cand_symbols,   # [batch_size, :]
            past_cand_logprobs,  # [batch_size]
            past_beam_symbols,   # [batch_size*self.beam_size, :], right-aligned!!!
            past_beam_logprobs,  # [batch_size*self.beam_size]
            past_cell_states,    # LSTM: ([batch_size*self.beam_size, :, dim],
                                 #        [batch_size*self.beam_size, :, dim])
                                 # GRU: [batch_size*self.beam_size, :, dim]
        ) = state
        batch_size = past_cand_symbols.get_shape()[0].value
        full_size = batch_size * self.beam_size
        if self.parent_refs_offsets is None:
            self.parent_refs_offsets = \
                (tf.range(full_size) // self.beam_size) * self.beam_size

        input_symbols = past_beam_symbols[:, -1]
        # [batch_size * beam_size]
        # - 0
        # _STOP 1
        # - 0
        # - 0
        stop_mask = tf.cast(tf.equal(input_symbols, self.stop_token),
                            tf.float32)

        cell_inputs = inputs

        past_cell_state = self.get_last_cell_state(past_cell_states)
        if self.use_copy and self.copy_fun != 'supervised':
            cell_output, cell_state, alignments, attns, read_copy_source = \
                self.cell(cell_inputs, past_cell_state, scope)
        elif self.use_attention:
            cell_output, cell_state, alignments, attns = \
                self.cell(cell_inputs, past_cell_state, scope)
        else:
            cell_output, cell_state = \
                self.cell(cell_inputs, past_cell_state, scope)

        # [batch_size*beam_size, num_classes]
        if self.use_copy and self.copy_fun == 'copynet':
            logprobs = tf.log(cell_output)
        else:
            W, b = self.output_project
            if self.locally_normalized:
                logprobs = tf.nn.log_softmax(tf.matmul(cell_output, W) + b)
            else:
                logprobs = tf.matmul(cell_output, W) + b

        # set the probabilities of all other symbols following the stop symbol
        # to a very small number
        stop_mask_2d = tf.expand_dims(stop_mask, 1)
        # [- - _STOP - - - ]
        # [0 0 0 0 0 0]
        # [-100 -100 0 -100 -100 -100]
        # [0 0 0 0 0 0]
        # [0 0 0 0 0 0]
        done_only_mask = tf.multiply(stop_mask_2d, self._done_mask)
        # [- - _STOP - - - ]
        # [1 1 1 1 1 1]
        # [1 1 0 1 1 1]
        # [1 1 1 1 1 1]
        # [1 1 1 1 1 1]
        zero_done_mask = tf.ones([full_size, self.num_classes]) - \
            tf.multiply(stop_mask_2d, tf.cast(tf.equal(self._done_mask, 0), tf.float32))
        logprobs = tf.add(logprobs, done_only_mask)
        logprobs = tf.multiply(logprobs, zero_done_mask)

        # length normalization
        past_beam_acc_logprobs = \
            tf.multiply(past_beam_logprobs, tf.pow(self.seq_len, self.alpha))
        logprobs_batched = tf.expand_dims(past_beam_acc_logprobs, 1) + logprobs
        float_done_mask = tf.cast(tf.not_equal(self._done_mask, 0), tf.float32)
        seq_len = tf.expand_dims(self.seq_len, 1) + \
            tf.multiply(tf.ones([full_size, 1]) - stop_mask_2d, float_done_mask)
        logprobs_batched = tf.div(logprobs_batched,
                                  tf.pow(seq_len, self.alpha))
        logprobs_batched = \
            tf.reshape(logprobs_batched, [-1, self.beam_size * self.num_classes])

        beam_logprobs, indices = tf.nn.top_k(
            # TODO: make sure it's reasonable to remove nondone mask
            logprobs_batched, self.beam_size)
        beam_logprobs = tf.reshape(beam_logprobs, [-1])

        # For continuing to the next symbols
        symbols = indices % self.num_classes  # [batch_size, self.beam_size]
        parent_refs = tf.reshape(indices // self.num_classes,
                                 [-1])  # [batch_size*self.beam_size]
        parent_refs = parent_refs + self.parent_refs_offsets

        beam_symbols = tf.concat(axis=1,
                                 values=[
                                     tf.gather(past_beam_symbols, parent_refs),
                                     tf.reshape(symbols, [-1, 1])
                                 ])
        self.seq_len = tf.gather(self.seq_len, parent_refs) + \
                       tf.cast(tf.not_equal(tf.reshape(symbols, [-1]),
                                            self.stop_token), tf.float32)

        if self.use_copy and self.copy_fun != 'supervised':
            ranked_read_copy_source = tf.gather(read_copy_source, parent_refs)
        if self.use_attention:
            ranked_alignments = nest_map(
                lambda element: tf.gather(element, parent_refs), alignments)
            ranked_attns = nest_map(
                lambda element: tf.gather(element, parent_refs), attns)

        # update cell_states
        def concat_and_gather_tuple_states(pc_states, c_state):
            rc_states = (
                tf.concat(axis=1, values=[pc_states[0], tf.expand_dims(c_state[0], 1)]),
                tf.concat(axis=1, values=[pc_states[1], tf.expand_dims(c_state[1], 1)])
            )
            c_states = (
                nest_map(lambda element: tf.gather(element, parent_refs), rc_states[0]),
                nest_map(lambda element: tf.gather(element, parent_refs), rc_states[1])
            )
            return c_states

        if nest.is_sequence(cell_state):
            if self.num_layers > 1:
                ranked_cell_states = [
                    concat_and_gather_tuple_states(pc_states, c_state)
                    for pc_states, c_state in zip(past_cell_states, cell_state)
                ]
            else:
                ranked_cell_states = concat_and_gather_tuple_states(
                    past_cell_states, cell_state)
        else:
            ranked_cell_states = tf.gather(
                tf.concat(
                    axis=1,
                    values=[past_cell_states,
                            tf.expand_dims(cell_state, 1)]), parent_refs)

        # Handling for getting a done token
        logprobs_batched_3D = tf.reshape(
            logprobs_batched, [-1, self.beam_size, self.num_classes])
        logprobs_done = logprobs_batched_3D[:, :, self.stop_token]
        done_parent_refs = tf.to_int32(tf.argmax(logprobs_done, 1))
        done_parent_refs_offsets = tf.range(batch_size) * self.beam_size
        done_symbols = tf.gather(past_beam_symbols[:, -1:],
                                 done_parent_refs + done_parent_refs_offsets)

        logprobs_done_max = tf.reduce_max(logprobs_done, 1)
        cand_symbols = tf.where(logprobs_done_max > past_cand_logprobs,
                                done_symbols, past_cand_symbols)
        cand_logprobs = tf.maximum(logprobs_done_max, past_cand_logprobs)

        compound_cell_state = (cand_symbols, cand_logprobs, beam_symbols,
                               beam_logprobs, ranked_cell_states)
        ranked_cell_output = tf.gather(cell_output, parent_refs)

        if self.use_copy and self.copy_fun == 'copynet':
            return ranked_cell_output, compound_cell_state, ranked_alignments, \
                   ranked_attns, ranked_read_copy_source
        elif self.use_attention:
            return ranked_cell_output, compound_cell_state, ranked_alignments, \
                   ranked_attns
        else:
            return ranked_cell_output, compound_cell_state
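The done-token handling at the end keeps, per example, the better of the best newly finished hypothesis and the running best candidate. A NumPy sketch of that update (scores and symbols are made up for illustration):

    import numpy as np

    # Best score of emitting the stop token from each beam, shape [batch, beam].
    logprobs_done = np.array([[-1.2, -0.4],
                              [-3.0, -5.1]])
    past_cand_logprobs = np.array([-0.3, -4.0])  # running best finished scores
    past_cand_symbols = np.array([10, 11])
    done_symbols = np.array([20, 21])  # symbols of the best finishing beams

    logprobs_done_max = logprobs_done.max(axis=1)      # [-0.4, -3.0]
    improved = logprobs_done_max > past_cand_logprobs  # [False, True]
    cand_symbols = np.where(improved, done_symbols, past_cand_symbols)
    cand_logprobs = np.maximum(logprobs_done_max, past_cand_logprobs)
    print(cand_symbols, cand_logprobs)  # [10 21] [-0.3 -3. ]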