Example #1
def dynamic_rnn(cell,
                inputs,
                sequence_length=None,
                initial_state=None,
                dtype=None,
                parallel_iterations=None,
                swap_memory=False,
                time_major=False,
                scope=None):
    assert not time_major  # TODO: time_major=True is not supported yet
    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_len = None if sequence_length is None else tf.cast(
        flatten(sequence_length, 0), 'int64')

    flat_outputs, final_state = _dynamic_rnn(
        cell,
        flat_inputs,
        sequence_length=flat_len,
        initial_state=initial_state,
        dtype=dtype,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        time_major=time_major,
        scope=scope)

    outputs = reconstruct(flat_outputs, inputs, 2)
    return outputs, final_state
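
These wrappers all rely on flatten/reconstruct helpers that are not shown in the listing. Judging by how they are used (collapse every dimension except the last keep dims into one batch dimension, then restore the leading dimensions from a reference tensor), a minimal sketch could look like the following; the actual helpers in the source repository may differ in detail.

from functools import reduce
from operator import mul

import tensorflow as tf


def flatten(tensor, keep):
    # Collapse all but the last `keep` dims into a single leading dim,
    # e.g. [N, M, JX, d] with keep=2 -> [N * M, JX, d].
    fixed_shape = tensor.get_shape().as_list()
    start = len(fixed_shape) - keep
    left = reduce(mul, [fixed_shape[i] or tf.shape(tensor)[i] for i in range(start)])
    out_shape = [left] + [fixed_shape[i] or tf.shape(tensor)[i]
                          for i in range(start, len(fixed_shape))]
    return tf.reshape(tensor, out_shape)


def reconstruct(tensor, ref, keep):
    # Restore the leading dims of `ref` while keeping the last `keep` dims of `tensor`,
    # e.g. ([N * M, JX, 2d], ref=[N, M, JX, d], keep=2) -> [N, M, JX, 2d].
    ref_shape = ref.get_shape().as_list()
    tensor_shape = tensor.get_shape().as_list()
    ref_stop = len(ref_shape) - keep
    tensor_start = len(tensor_shape) - keep
    pre_shape = [ref_shape[i] or tf.shape(ref)[i] for i in range(ref_stop)]
    keep_shape = [tensor_shape[i] or tf.shape(tensor)[i]
                  for i in range(tensor_start, len(tensor_shape))]
    return tf.reshape(tensor, pre_shape + keep_shape)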
Example #2
def bidirectional_dynamic_rnn(cell_fw,
                              cell_bw,
                              inputs,
                              sequence_length=None,
                              initial_state_fw=None,
                              initial_state_bw=None,
                              dtype=None,
                              parallel_iterations=None,
                              swap_memory=False,
                              time_major=False,
                              scope=None):
    assert not time_major  # TODO: time_major=True is not supported yet

    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_len = None if sequence_length is None else tf.cast(
        flatten(sequence_length, 0), 'int64')

    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_dynamic_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                                   initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                                   dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory,
                                   time_major=time_major, scope=scope)

    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    print(flat_fw_outputs, " reconstruct -> ", fw_outputs)
    # FIXME : final state is not reshaped!
    return (fw_outputs, bw_outputs), final_state
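
As the FIXME notes, the returned final_state still has the flattened [N * M, ...] batch dimension. A usage sketch with the same BasicLSTMCell used in Example #8 (the shapes and variable names here are illustrative):

# xx:    [N, M, JX, d]  token representations
# x_len: [N, M]         per-sentence lengths
cell_fw = BasicLSTMCell(d, state_is_tuple=True)
cell_bw = BasicLSTMCell(d, state_is_tuple=True)
(fw_h, bw_h), final_state = bidirectional_dynamic_rnn(
    cell_fw, cell_bw, xx, sequence_length=x_len, dtype='float', scope='h1')
h = tf.concat([fw_h, bw_h], 3)  # [N, M, JX, 2d]; final_state remains [N * M, ...]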
Example #3
def softmax(logits, mask=None, scope=None):
    with tf.name_scope(scope or "Softmax"):
        if mask is not None:
            logits = exp_mask(logits, mask)
        flat_logits = flatten(logits, 1)
        flat_out = tf.nn.softmax(flat_logits)
        out = reconstruct(flat_out, logits, 1)

        return out
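
exp_mask is not shown in the listing; from its use here it must push masked-out positions toward a very large negative logit before the softmax. A minimal sketch, where the constant value is an assumption:

VERY_NEGATIVE_NUMBER = -1e30  # assumed constant; any sufficiently large negative value works


def exp_mask(val, mask, name=None):
    # Masked-out (mask == 0/False) positions get a huge negative logit,
    # so they receive ~0 probability after the softmax.
    return tf.add(val, (1.0 - tf.cast(mask, 'float')) * VERY_NEGATIVE_NUMBER, name=name)

Example #8 uses this as Ai = softmax(Sj, mask=tf.expand_dims(x_mask, 1)), broadcasting one mask row over every query position.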
Example #4
def bidirectional_rnn(cell_fw,
                      cell_bw,
                      inputs,
                      initial_state_fw=None,
                      initial_state_bw=None,
                      dtype=None,
                      sequence_length=None,
                      scope=None):

    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_len = None if sequence_length is None else tf.cast(
        flatten(sequence_length, 0), 'int64')

    (flat_fw_outputs, flat_bw_outputs), final_state = \
        _bidirectional_rnn(cell_fw, cell_bw, flat_inputs, sequence_length=flat_len,
                           initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw,
                           dtype=dtype, scope=scope)

    fw_outputs = reconstruct(flat_fw_outputs, inputs, 2)
    bw_outputs = reconstruct(flat_bw_outputs, inputs, 2)
    # FIXME : final state is not reshaped!
    return (fw_outputs, bw_outputs), final_state
Example #5
    def __init__(self,
                 cell,
                 memory,
                 size,
                 mask=None,
                 controller=None,
                 mapper=None,
                 input_keep_prob=1.0,
                 is_train=None):
        """
        Early fusion attention cell: uses the (inputs, state) to control the current attention.

        :param cell:
        :param memory: [N, M, m]
        :param mask:
        :param controller: (inputs, prev_state, memory) -> memory_logits
        """
        self._cell = cell
        self._memory = memory  # u, (Q), [N, M, JX, 2d]
        self._mask = mask
        self._flat_memory = flatten(memory, 2)
        self._flat_mask = flatten(mask, 1)
        if controller is None:
            controller = AttentionCell.get_double_linear_controller(
                size, True, input_keep_prob=input_keep_prob, is_train=is_train)
            self.A_m = linear(self._memory,
                              size,
                              True,
                              scope='memory_prepare',
                              input_keep_prob=input_keep_prob,
                              is_train=is_train)  # [N * M, JX, d]
        self._controller = controller
        if mapper is None:
            mapper = AttentionCell.get_concat_mapper()
        elif mapper == 'sim':
            mapper = AttentionCell.get_sim_mapper()
        self._mapper = mapper
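
AttentionCell.get_double_linear_controller is not shown; the docstring only fixes the controller contract, (inputs, prev_state, memory) -> memory_logits. A hypothetical custom controller compatible with that contract, purely for illustration (it assumes a flat, non-tuple state such as a GRU's):

def dot_product_controller(inputs, prev_state, memory):
    # inputs:     [B, i]      current cell input
    # prev_state: [B, s]      previous cell state (assumed flat, e.g. GRU)
    # memory:     [B, M, m]   memory to attend over
    # returns:    [B, M]      unnormalized attention logits
    query = tf.concat([inputs, prev_state], axis=1)                   # [B, i + s]
    query = tf.layers.dense(query, memory.get_shape().as_list()[-1])  # project to [B, m]
    return tf.reduce_sum(memory * tf.expand_dims(query, 1), axis=2)   # [B, M]

The cell presumably masks and normalizes these logits before using them to pool the memory.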
Example #6
def linear(args, output_size, bias, bias_start=0.0, scope=None, squeeze=False, wd=0.0, input_keep_prob=1.0,
           is_train=None):
    if args is None or (nest.is_sequence(args) and not args):
        raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):
        args = [args]
    # args = [N, M, JX, JQ, 6d], output_size = 1
    flat_args = [flatten(arg, 1) for arg in args] # [N*M, JX, JQ, 6d]
    if input_keep_prob < 1.0:
        assert is_train is not None
        flat_args = [tf.cond(is_train, lambda: tf.nn.dropout(arg, input_keep_prob), lambda: arg)
                     for arg in flat_args]
    with tf.variable_scope(scope or "linear"):
        flat_out = _linear(flat_args, output_size, bias, bias_start=bias_start, scope=scope)
        out = reconstruct(flat_out, args[0], 1)
        if squeeze:
            out = tf.squeeze(out, [len(args[0].get_shape().as_list())-1])
        if wd:
            add_wd(wd)

    return out
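
The inline shape comment hints at the typical call from the attention layer: a rank-5 feature tensor is scored with output_size=1 and the size-1 output axis is squeezed away. A usage sketch with illustrative tensor names:

# h_aug, u_aug: [N, M, JX, JQ, 2d] tiled context / question features
sim_logits = linear([h_aug, u_aug, h_aug * u_aug], 1, bias=True,
                    squeeze=True, scope='similarity', wd=config.wd,
                    input_keep_prob=config.input_keep_prob,
                    is_train=is_train)  # [N, M, JX, JQ]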
Example #7
def build_fused_bidirectional_rnn(inputs, num_units, num_layers, inputs_length, input_keep_prob=1.0, scope=None, dtype=tf.float32):
    ''' Stacked fused bidirectional LSTM. The fused cells consume time-major sequences,
        so the batch-major inputs are flattened and transposed internally.
        Dropout uses an independent mask at each time step. '''
    assert num_layers > 0
    with tf.variable_scope(scope or "bidirectional_rnn"):
        inputs = flatten(inputs, 2) # [N * M, JX, d]
        current_inputs = tf.transpose(inputs, [1, 0, 2]) #[time_len, batch_size, input_size]
        for layer_id in range(num_layers):
            reverse_inputs = tf.reverse_sequence(current_inputs, inputs_length, batch_dim=1, seq_dim=0)
            fw_inputs = tf.nn.dropout(current_inputs, input_keep_prob)
            bw_inputs = tf.nn.dropout(reverse_inputs, input_keep_prob)
            fw_cell = LSTMBlockFusedCell(num_units, cell_clip=0)
            bw_cell = LSTMBlockFusedCell(num_units, cell_clip=0)
            fw_outputs, fw_final = fw_cell(fw_inputs, dtype=dtype, sequence_length=inputs_length, scope="fw_"+str(layer_id))
            bw_outputs, bw_final = bw_cell(bw_inputs, dtype=dtype, sequence_length=inputs_length, scope="bw_"+str(layer_id))
            bw_outputs = tf.reverse_sequence(bw_outputs, inputs_length, batch_dim=1, seq_dim=0)
            current_inputs = tf.concat((fw_outputs, bw_outputs), 2)
        output = tf.transpose(current_inputs, [1, 0, 2])
        output = tf.expand_dims(output, 1) # [N, M, JX, 2d]

        name_prefix = scope or "bidirectional_rnn"
        final_state_c = tf.concat((fw_final[0], bw_final[0]), 1, name=name_prefix + '_final_c')
        final_state_h = tf.concat((fw_final[1], bw_final[1]), 1, name=name_prefix + '_final_h')
        final_state = LSTMStateTuple(c=final_state_c, h=final_state_h)  # ([N, 2 * d], [N, 2 * d])
        return output, final_state
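
Example #8 calls this from its modeling layer; the shapes below follow the comments there:

# p0: attention-flow output [N, M, JX, d_p]; flat_x_len: [N * M] sentence lengths
g1, encoder_state_final = build_fused_bidirectional_rnn(
    inputs=p0,
    num_units=config.hidden_size,
    num_layers=config.num_modeling_layers,
    inputs_length=flat_x_len,
    input_keep_prob=config.input_keep_prob,
    scope='modeling_layer_g')  # g1: [N, M, JX, 2d]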
Example #8
    def _build_forward(self):
        config = self.config
        N, M, JX, JQ, VW, VC, d, W = \
            config.batch_size, config.max_num_sents, config.max_sent_size, \
            config.max_ques_size, config.word_vocab_size, config.char_vocab_size, config.hidden_size, \
            config.max_word_size
        print("VW:", VW, "N:", N, "M:", M, "JX:", JX, "JQ:", JQ)
        JA = config.max_answer_length
        # Override the static sizes with dynamic shapes where they can vary per batch.
        JX = tf.shape(self.x)[2]
        JQ = tf.shape(self.q)[1]
        M = tf.shape(self.x)[1]
        print("VW:", VW, "N:", N, "M:", M, "JX:", JX, "JQ:", JQ)
        dc, dw, dco = config.char_emb_size, config.word_emb_size, config.char_out_size

        with tf.variable_scope("emb"):
            # Char-CNN Embedding
            if config.use_char_emb:
                with tf.variable_scope("emb_var"), tf.device("/cpu:0"):
                    char_emb_mat = tf.get_variable("char_emb_mat", shape=[VC, dc], dtype='float')

                with tf.variable_scope("char"):
                    Acx = tf.nn.embedding_lookup(char_emb_mat, self.cx)  # [N, M, JX, W, dc]
                    Acq = tf.nn.embedding_lookup(char_emb_mat, self.cq)  # [N, JQ, W, dc]
                    Acx = tf.reshape(Acx, [-1, JX, W, dc])
                    Acq = tf.reshape(Acq, [-1, JQ, W, dc])

                    filter_sizes = list(map(int, config.out_channel_dims.split(','))) # [100]
                    heights = list(map(int, config.filter_heights.split(','))) # [5]
                    assert sum(filter_sizes) == dco, (filter_sizes, dco) # Make sure filter channels = char_cnn_out size
                    with tf.variable_scope("conv"):
                        xx = multi_conv1d(Acx, filter_sizes, heights, "VALID",  self.is_train, config.keep_prob, scope="xx")
                        if config.share_cnn_weights:
                            tf.get_variable_scope().reuse_variables()
                            qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="xx")
                        else:
                            qq = multi_conv1d(Acq, filter_sizes, heights, "VALID", self.is_train, config.keep_prob, scope="qq")
                        xx = tf.reshape(xx, [-1, M, JX, dco])
                        qq = tf.reshape(qq, [-1, JQ, dco])

            # Word Embedding
            if config.use_word_emb:
                with tf.variable_scope("emb_var") as scope, tf.device("/cpu:0"):
                    if config.mode == 'train':
                        word_emb_mat = tf.get_variable("word_emb_mat", dtype='float', shape=[VW, dw], initializer=get_initializer(config.emb_mat))
                    else:
                        word_emb_mat = tf.get_variable("word_emb_mat", shape=[VW, dw], dtype='float')
                    tf.get_variable_scope().reuse_variables()
                    self.word_emb_scope = scope
                    if config.use_glove_for_unk:
                        word_emb_mat = tf.concat([word_emb_mat, self.new_emb_mat], 0)

                with tf.name_scope("word"):
                    Ax = tf.nn.embedding_lookup(word_emb_mat, self.x)  # [N, M, JX, d]
                    Aq = tf.nn.embedding_lookup(word_emb_mat, self.q)  # [N, JQ, d]
                    self.tensor_dict['x'] = Ax
                    self.tensor_dict['q'] = Aq
                # Concat Char-CNN Embedding and Word Embedding
                if config.use_char_emb:
                    xx = tf.concat([xx, Ax], 3)  # [N, M, JX, di]
                    qq = tf.concat([qq, Aq], 2)  # [N, JQ, di]
                else:
                    xx = Ax
                    qq = Aq

            # exact match
            if config.use_exact_match:  # append binary exact-match features to the embeddings
                emx = tf.expand_dims(tf.cast(self.emx, tf.float32), -1)
                xx = tf.concat([xx, emx], 3)  # [N, M, JX, di+1]
                emq = tf.expand_dims(tf.cast(self.emq, tf.float32), -1)
                qq = tf.concat([qq, emq], 2)  # [N, JQ, di+1]


        # 2 layer highway network on Concat Embedding
        if config.highway:
            with tf.variable_scope("highway"):
                xx = highway_network(xx, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train)
                tf.get_variable_scope().reuse_variables()
                qq = highway_network(qq, config.highway_num_layers, True, wd=config.wd, is_train=self.is_train)

        self.tensor_dict['xx'] = xx
        self.tensor_dict['qq'] = qq

        # Bidirection-LSTM (3rd layer on paper)
        cell = GRUCell(d) if config.GRU else BasicLSTMCell(d, state_is_tuple=True)
        d_cell = SwitchableDropoutWrapper(cell, self.is_train, input_keep_prob=config.input_keep_prob)
        x_len = tf.reduce_sum(tf.cast(self.x_mask, 'int32'), 2)  # [N, M]
        q_len = tf.reduce_sum(tf.cast(self.q_mask, 'int32'), 1)  # [N]
        flat_x_len = flatten(x_len, 0)  # [N * M]

        with tf.variable_scope("prepro"):
            if config.use_fused_lstm:  # used in this configuration
                with tf.variable_scope("u1"):
                    fw_inputs = tf.transpose(qq, [1, 0, 2]) #[time_len, batch_size, input_size]
                    bw_inputs = tf.reverse_sequence(fw_inputs, q_len, batch_dim=1, seq_dim=0)
                    fw_inputs = tf.nn.dropout(fw_inputs, config.input_keep_prob)
                    bw_inputs = tf.nn.dropout(bw_inputs, config.input_keep_prob)
                    prep_fw_cell = LSTMBlockFusedCell(d, cell_clip=0)
                    prep_bw_cell = LSTMBlockFusedCell(d, cell_clip=0)
                    fw_outputs, fw_final = prep_fw_cell(fw_inputs, dtype=tf.float32, sequence_length=q_len, scope="fw")
                    bw_outputs, bw_final = prep_bw_cell(bw_inputs, dtype=tf.float32, sequence_length=q_len, scope="bw")
                    bw_outputs = tf.reverse_sequence(bw_outputs, q_len, batch_dim=1, seq_dim=0)
                    current_inputs = tf.concat((fw_outputs, bw_outputs), 2)
                    output = tf.transpose(current_inputs, [1, 0, 2])
                    u = output
                flat_xx = flatten(xx, 2)  # [N * M, JX, d]
                if config.share_lstm_weights:  # used in this configuration
                    tf.get_variable_scope().reuse_variables()
                    with tf.variable_scope("u1"):
                        fw_inputs = tf.transpose(flat_xx, [1, 0, 2]) #[time_len, batch_size, input_size]
                        bw_inputs = tf.reverse_sequence(fw_inputs, flat_x_len, batch_dim=1, seq_dim=0)
                        # fw_inputs = tf.nn.dropout(fw_inputs, config.input_keep_prob)
                        # bw_inputs = tf.nn.dropout(bw_inputs, config.input_keep_prob)
                        fw_outputs, fw_final = prep_fw_cell(fw_inputs, dtype=tf.float32, sequence_length=flat_x_len, scope="fw")
                        bw_outputs, bw_final = prep_bw_cell(bw_inputs, dtype=tf.float32, sequence_length=flat_x_len, scope="bw")
                        bw_outputs = tf.reverse_sequence(bw_outputs, flat_x_len, batch_dim=1, seq_dim=0)
                        current_inputs = tf.concat((fw_outputs, bw_outputs), 2)
                        output = tf.transpose(current_inputs, [1, 0, 2])
                else:  # not used in this configuration
                    with tf.variable_scope("h1"):
                        fw_inputs = tf.transpose(flat_xx, [1, 0, 2]) #[time_len, batch_size, input_size]
                        bw_inputs = tf.reverse_sequence(fw_inputs, flat_x_len, batch_dim=1, seq_dim=0)
                        # fw_inputs = tf.nn.dropout(fw_inputs, config.input_keep_prob)
                        # bw_inputs = tf.nn.dropout(bw_inputs, config.input_keep_prob)
                        prep_fw_cell = LSTMBlockFusedCell(d, cell_clip=0)
                        prep_bw_cell = LSTMBlockFusedCell(d, cell_clip=0)
                        fw_outputs, fw_final = prep_fw_cell(fw_inputs, dtype=tf.float32, sequence_length=flat_x_len, scope="fw")
                        bw_outputs, bw_final = prep_bw_cell(bw_inputs, dtype=tf.float32, sequence_length=flat_x_len, scope="bw")
                        bw_outputs = tf.reverse_sequence(bw_outputs, flat_x_len, batch_dim=1, seq_dim=0)
                        current_inputs = tf.concat((fw_outputs, bw_outputs), 2)
                        output = tf.transpose(current_inputs, [1, 0, 2])
                h = tf.expand_dims(output, 1) # [N, M, JX, 2d]
            else:
                (fw_u, bw_u), _ = bidirectional_dynamic_rnn(d_cell, d_cell, qq, q_len, dtype='float', scope='u1')  # [N, J, d], [N, d]
                u = tf.concat([fw_u, bw_u], 2)
                if config.share_lstm_weights:
                    tf.get_variable_scope().reuse_variables()
                    (fw_h, bw_h), _ = bidirectional_dynamic_rnn(cell, cell, xx, x_len, dtype='float', scope='u1')  # [N, M, JX, 2d]
                    h = tf.concat([fw_h, bw_h], 3)  # [N, M, JX, 2d]
                else:
                    (fw_h, bw_h), _ = bidirectional_dynamic_rnn(cell, cell, xx, x_len, dtype='float', scope='h1')  # [N, M, JX, 2d]
                    h = tf.concat([fw_h, bw_h], 3)  # [N, M, JX, 2d]
            self.tensor_dict['u'] = u # hidden state of Q = u
            self.tensor_dict['h'] = h # hidden state of C = h

        # Attention Flow Layer (4th layer on paper)
        with tf.variable_scope("main"):
            if config.dynamic_att:
                p0 = h
                u = tf.reshape(tf.tile(tf.expand_dims(u, 1), [1, M, 1, 1]), [N * M, JQ, 2 * d])
                q_mask = tf.reshape(tf.tile(tf.expand_dims(self.q_mask, 1), [1, M, 1]), [N * M, JQ])
                first_cell = AttentionCell(cell, u, size=d, mask=q_mask, mapper='sim',
                                           input_keep_prob=self.config.input_keep_prob, is_train=self.is_train)
            else:
                p0 = attention_layer(config, self.is_train, h, u, h_mask=self.x_mask, u_mask=self.q_mask, scope="p0", tensor_dict=self.tensor_dict)
                first_cell = d_cell  # the dropout-wrapped base cell (GRU or LSTM, per config.GRU)
            tp0 = p0 # Output of Attention layer

        # Modeling layer (5th layer on paper)
        with tf.variable_scope('modeling_layer'):
            if config.use_fused_lstm:
                g1, encoder_state_final = build_fused_bidirectional_rnn(inputs=p0,
                                                                        num_units=config.hidden_size,
                                                                        num_layers=config.num_modeling_layers,
                                                                        inputs_length=flat_x_len,
                                                                        input_keep_prob=config.input_keep_prob,
                                                                        scope='modeling_layer_g')

            else:
                for layer_idx in range(config.num_modeling_layers-1):
                    (fw_g0, bw_g0), _ = bidirectional_dynamic_rnn(first_cell, first_cell, p0, x_len,
                                                                  dtype='float', scope="g_{}".format(layer_idx))  # [N, M, JX, 2d]
                    p0 = tf.concat([fw_g0, bw_g0], 3)
                (fw_g1, bw_g1), (fw_s_f, bw_s_f) = bidirectional_dynamic_rnn(first_cell, first_cell, p0, x_len,
                                                                             dtype='float', scope='g1')  # [N, M, JX, 2d]
                g1 = tf.concat([fw_g1, bw_g1], 3)  # [N, M, JX, 2d]

        # Self match layer
        if config.use_self_match:
            s0 = tf.reshape(g1, [N * M, JX, 2 * d])                     # [N * M, JX, 2d]
            x_mask = tf.reshape(self.x_mask, [N * M, JX])               # [N * M, JX]
            if config.use_static_self_match:
                with tf.variable_scope("StaticSelfMatch"):              # implemented follow r-net section 3.3
                    W_x_Vj = tf.contrib.layers.fully_connected(         # [N * M, JX, d]
                        s0, int(d / 2), scope='row_first',
                        activation_fn=None, biases_initializer=None
                    )
                    W_x_Vt = tf.contrib.layers.fully_connected(         # [N * M, JX, d]
                        s0, int(d / 2), scope='col_first',
                        activation_fn=None, biases_initializer=None
                    )
                    sum_rc = tf.add(                                    # [N * M, JX, JX, d]
                        tf.expand_dims(W_x_Vj, 1),
                        tf.expand_dims(W_x_Vt, 2)
                    )
                    v = tf.get_variable('second', shape=[1, 1, 1, int(d / 2)], dtype=tf.float32)
                    Sj = tf.reduce_sum(tf.multiply(v, tf.tanh(sum_rc)), -1)     # [N * M, JX, JX]
                    Ai = softmax(Sj, mask=tf.expand_dims(x_mask, 1))            # [N * M, JX, JX]
                    Ai = tf.expand_dims(Ai, -1)                                 # [N * M, JX, JX, 1]
                    Vi = tf.expand_dims(s0, 1)                                  # [N * M, 1, JX, 2d]
                    Ct = tf.reduce_sum(                                         # [N * M, JX, 2d]
                        tf.multiply(Ai, Vi),
                        axis=2
                    )
                    inputs_Vt_Ct = tf.concat([s0, Ct], 2)                       # [N * M, JX, 4d]
                    if config.use_fused_lstm:
                        fw_inputs = tf.transpose(inputs_Vt_Ct, [1, 0, 2])  # [time_len, batch_size, input_size]
                        bw_inputs = tf.reverse_sequence(fw_inputs, flat_x_len, batch_dim=1, seq_dim=0)
                        fw_inputs = tf.nn.dropout(fw_inputs, config.input_keep_prob)
                        bw_inputs = tf.nn.dropout(bw_inputs, config.input_keep_prob)
                        prep_fw_cell = LSTMBlockFusedCell(d, cell_clip=0)
                        prep_bw_cell = LSTMBlockFusedCell(d, cell_clip=0)
                        fw_outputs, fw_s_f = prep_fw_cell(fw_inputs, dtype=tf.float32, sequence_length=flat_x_len,
                                                            scope="fw")
                        bw_outputs, bw_s_f = prep_bw_cell(bw_inputs, dtype=tf.float32, sequence_length=flat_x_len,
                                                            scope="bw")
                        fw_s_f = LSTMStateTuple(c=fw_s_f[0], h=fw_s_f[1])
                        bw_s_f = LSTMStateTuple(c=bw_s_f[0], h=bw_s_f[1])
                        bw_outputs = tf.reverse_sequence(bw_outputs, flat_x_len, batch_dim=1, seq_dim=0)
                        current_inputs = tf.concat((fw_outputs, bw_outputs), 2)
                        s1 = tf.transpose(current_inputs, [1, 0, 2])
                    else:
                        (fw_s, bw_s), (fw_s_f, bw_s_f) = bidirectional_dynamic_rnn(first_cell, first_cell, inputs_Vt_Ct,
                                                                                   flat_x_len, dtype='float',
                                                                                   scope='s')  # [N, M, JX, 2d]
                        s1 = tf.concat([fw_s, bw_s], 2)  # [N * M, JX, 2d], M == 1
            else:
                with tf.variable_scope("DynamicSelfMatch"):
                    first_cell = AttentionCell(cell, s0, size=d, mask=x_mask, is_train=self.is_train)
                    (fw_s, bw_s), (fw_s_f, bw_s_f) = bidirectional_dynamic_rnn(first_cell, first_cell, s0, x_len,
                                                                               dtype='float', scope='s')  # [N, M, JX, 2d]
                    s1 = tf.concat([fw_s, bw_s], 2)  # [N * M, JX, 2d], M == 1
            g1 = tf.expand_dims(s1, 1) # [N, M, JX, 2d]

        # prepare for PtrNet
        encoder_output = g1  # [N, M, JX, 2d]
        encoder_output = tf.expand_dims(tf.cast(self.x_mask, tf.float32), -1) * encoder_output  # [N, M, JX, 2d]

        if config.use_self_match or not config.use_fused_lstm:
            if config.GRU:
                encoder_state_final = tf.concat((fw_s_f, bw_s_f), 1, name='encoder_concat')
            else:
                if isinstance(fw_s_f, LSTMStateTuple):
                    encoder_state_c = tf.concat(
                        (fw_s_f.c, bw_s_f.c), 1, name='encoder_concat_c')
                    encoder_state_h = tf.concat(
                        (fw_s_f.h, bw_s_f.h), 1, name='encoder_concat_h')
                    encoder_state_final = LSTMStateTuple(c=encoder_state_c, h=encoder_state_h)
                elif isinstance(fw_s_f, tf.Tensor):
                    encoder_state_final = tf.concat((fw_s_f, bw_s_f), 1, name='encoder_concat')
                else:
                    encoder_state_final = None
                    tf.logging.error("encoder_state_final not set")

        print("encoder_state_final:", encoder_state_final)

        with tf.variable_scope("output"):
            # eos_symbol = config.eos_symbol
            # next_symbol = config.next_symbol

            tf.assert_equal(M, 1)  # currently dynamic M is not supported, thus we assume M==1
            answer_string = tf.placeholder(
                shape=(N, 1, JA + 1),
                dtype=tf.int32,
                name='answer_string'
            )  # [N, M, JA + 1]
            answer_string_mask = tf.placeholder(
                shape=(N, 1, JA + 1),
                dtype=tf.bool,
                name='answer_string_mask'
            )  # [N, M, JA + 1]
            answer_string_length = tf.placeholder(
                shape=(N, 1),
                dtype=tf.int32,
                name='answer_string_length',
            ) # [N, M]
            self.tensor_dict['answer_string'] = answer_string
            self.tensor_dict['answer_string_mask'] = answer_string_mask
            self.tensor_dict['answer_string_length'] = answer_string_length
            self.answer_string = answer_string
            self.answer_string_mask = answer_string_mask
            self.answer_string_length = answer_string_length

            answer_string_flattened = tf.reshape(answer_string, [N * M, JA + 1])
            self.answer_string_flattened = answer_string_flattened  # [N * M, JA+1]
            print("answer_string_flattened:", answer_string_flattened)

            answer_string_length_flattened = tf.reshape(answer_string_length, [N * M])
            self.answer_string_length_flattened = answer_string_length_flattened  # [N * M]
            print("answer_string_length_flattened:", answer_string_length_flattened)

            decoder_cell = GRUCell(2 * d) if config.GRU else BasicLSTMCell(2 * d, state_is_tuple=True)

            with tf.variable_scope("Decoder"):
                decoder_train_logits = ptr_decoder(decoder_cell,
                                                   tf.reshape(tp0, [N * M, JX, 2 * d]),  # [N * M, JX, 2d]
                                                   tf.reshape(encoder_output, [N * M, JX, 2 * d]),  # [N * M, JX, 2d]
                                                   flat_x_len,
                                                   encoder_final_state=encoder_state_final,
                                                   max_encoder_length=config.sent_size_th,
                                                   decoder_output_length=answer_string_length_flattened,  # [N * M]
                                                   batch_size=N,  # N * M (M=1)
                                                   attention_proj_dim=self.config.decoder_proj_dim,
                                                   scope='ptr_decoder')  # [batch_size, dec_len*, enc_seq_len + 1]

                self.decoder_train_logits = decoder_train_logits
                print("decoder_train_logits:", decoder_train_logits)
                self.decoder_train_softmax = tf.nn.softmax(self.decoder_train_logits)
                self.decoder_inference = tf.argmax(decoder_train_logits, axis=2,
                                                   name='decoder_inference')  # [N, JA + 1]

            self.yp = tf.ones([N, M, JX], dtype=tf.int32) * -1
            self.yp2 = tf.ones([N, M, JX], dtype=tf.int32) * -1