Exemple #1
0
 def input_gru_step(self, x, prev_h):
     '''
     Run one GRU update with the input-module weights (W_inp_* / b_inp_*).

     The (x, prev_h) argument order matches a theano.scan step function:
     the current sequence element first, then the previous output.
     :param x: input for the GRU update
     :param prev_h: previous state
     :return: next state, as returned by nn_utils.GRU_update
     '''
     # Nine weights in the order nn_utils.GRU_update expects them:
     # (res_in, res_hid, b_res), (upd_in, upd_hid, b_upd), (hid_in, hid_hid, b_hid).
     input_weights = (
         self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
         self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
         self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
     )
     return nn_utils.GRU_update(prev_h, x, *input_weights)
Exemple #2
0
    def new_episode_step(self, ct, g, prev_h):
        '''
        Compute h_t^i for the memory-update mechanism.

        A GRU update with the memory weights (W_mem_* / b_mem_*) is blended
        with the previous state according to the attention gate g.
        :param ct: facts representation
        :param g: weights of the gates g^i (given by the attention mechanism)
        :param prev_h: previous state of the Mem GRU (h_t-1^i)
        :return: h_t^i, i.e. g * GRU(ct, prev_h) + (1 - g) * prev_h
        '''
        mem_weights = (
            self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
            self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
            self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid,
        )
        candidate = nn_utils.GRU_update(prev_h, ct, *mem_weights)

        # g = 1 takes the GRU output; g = 0 keeps the previous state unchanged.
        return g * candidate + (1 - g) * prev_h
Exemple #3
0
    def __init__(self, babi_train_raw, babi_test_raw, word2vec,
                 word_vector_size, dim, mode, answer_module, answer_step_nbr,
                 input_mask_mode, memory_hops, l2, normalize_attention,
                 max_input_size, **kwargs):
        '''
        Build the pointer variant of the DMN.

        Instead of predicting vocabulary words, the answer module outputs two
        distributions of size max_input_size: one over the start index and
        one over the end index of the answer span in the input.

        :param babi_train_raw: train dataset
        :param babi_test_raw: test dataset
        :param word2vec: a dictionary containing the word embeddings TODO: Check if right
        :param word_vector_size: dimension of the word embeddings (50,100,200,300)
        :param dim: number of hidden units in input module GRU
        :param mode: 'train', 'test' or 'minitest'. It selects the compiled
            Theano functions: train_fn only for 'train', test_fn for anything
            but 'minitest', and minitest_fn only for 'minitest'.
        :param answer_module: answer module type. It is stored but not used by
            this constructor (the pointer answer module is always built).
        :param answer_step_nbr: stored but not used by this constructor
        :param input_mask_mode: input_mask_mode: word or sentence
        :param memory_hops: number of episodic memory GRU steps
        :param l2: L2 regularization coefficient (no L2 term when <= 0)
        :param normalize_attention: enable softmax on attention vector
        :param max_input_size: maximal input length; size of the start/end
            pointer distributions
        :param **kwargs: unused; only their names are printed
        '''

        print("==> not used params in DMN class:", kwargs.keys())
        self.type = "pointer"
        self.vocab = {}
        self.ivocab = {}

        print("==> max_input_size:", max_input_size)

        # save params
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim  # number of hidden units in input layer GRU
        self.pointer_dim = max_input_size  # maximal size for the input, used as hyperparameter
        self.mode = mode
        self.answer_module = answer_module
        self.answer_step_nbr = answer_step_nbr
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention

        (self.train_input, self.train_q, self.train_answer,
         self.train_input_mask, self.train_pointers_s,
         self.train_pointers_e) = self._process_input(babi_train_raw)
        (self.test_input, self.test_q, self.test_answer,
         self.test_input_mask, self.test_pointers_s,
         self.test_pointers_e) = self._process_input(babi_test_raw)
        self.vocab_size = len(self.vocab)

        self.input_var = T.matrix('input_var')
        self.q_var = T.matrix('question_var')
        self.answer_var = T.ivector('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')
        self.pointers_s_var = T.ivector('pointers_s_var')
        self.pointers_e_var = T.ivector('pointers_e_var')

        print("==> building input module")
        self.W_inp_res_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_upd_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_hid_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        # Run the input GRU over the rows of input_var, starting from a zero
        # state; scan returns one hidden state per row.
        inp_c_history, _ = theano.scan(fn=self.input_gru_step,
                                       sequences=self.input_var,
                                       outputs_info=T.zeros_like(
                                           self.b_inp_hid))

        # in case of multiple sentences, only keep the hidden states which index match the <eos> char
        self.inp_c = inp_c_history.take(self.input_mask_var, axis=0)

        # Encode the question with the same input GRU weights and keep only
        # the final hidden state.
        self.q_q, _ = theano.scan(fn=self.input_gru_step,
                                  sequences=self.q_var,
                                  outputs_info=T.zeros_like(self.b_inp_hid))
        self.q_q = self.q_q[-1]

        print("==> creating parameters for memory module")
        self.W_mem_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        # Attention mechanism: 2-layer FFNN weights & bias
        # (presumably consumed by self.new_episode, which is not shown here)
        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1,
                                         shape=(self.dim, 7 * self.dim + 2))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim, ))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1, ))

        print(
            "==> building episodic memory module (fixed number of steps: %d)" %
            self.memory_hops)
        # The question encoding initializes the memory; each hop feeds the
        # episode built from the previous memory into the memory GRU.
        memory = [self.q_q.copy()]
        for _hop in range(self.memory_hops):
            prev_mem = memory[-1]
            current_episode = self.new_episode(prev_mem)
            memory.append(
                nn_utils.GRU_update(prev_mem, current_episode,
                                    self.W_mem_res_in, self.W_mem_res_hid,
                                    self.b_mem_res, self.W_mem_upd_in,
                                    self.W_mem_upd_hid, self.b_mem_upd,
                                    self.W_mem_hid_in, self.W_mem_hid_hid,
                                    self.b_mem_hid))

        self.last_mem = memory[-1]

        print("==> building answer module")
        # Forward (*_p) and reverse (*_pr) pointer weights. Each has shape
        # (pointer_dim, dim), so a dot with a dim-vector yields pointer_dim scores.
        self.Ws_p = nn_utils.normal_param(std=0.1,
                                          shape=(self.pointer_dim, self.dim))
        self.We_p = nn_utils.normal_param(std=0.1,
                                          shape=(self.pointer_dim, self.dim))
        self.Wh_p = nn_utils.normal_param(std=0.1,
                                          shape=(self.pointer_dim, self.dim))
        self.Ws_pr = nn_utils.normal_param(std=0.1,
                                           shape=(self.pointer_dim, self.dim))
        self.We_pr = nn_utils.normal_param(std=0.1,
                                           shape=(self.pointer_dim, self.dim))
        self.Wh_pr = nn_utils.normal_param(std=0.1,
                                           shape=(self.pointer_dim, self.dim))

        # Forward pass predicts the start first; reverse pass predicts the end first.
        self.Psp = nn_utils.softmax(T.dot(self.Ws_p, self.last_mem))
        self.Pepr = nn_utils.softmax(T.dot(self.We_pr, self.last_mem))

        self.start_idx = T.argmax(self.Psp)
        self.end_idxr = T.argmax(self.Pepr)

        # The input hidden state at the first predicted index conditions the
        # second pointer of each pass.
        # NOTE(review): indices range over pointer_dim, while inp_c_history
        # has one row per input row. An index past the actual input length is
        # not guarded here.
        self.start_idx_state = inp_c_history[self.start_idx]
        self.end_idx_state = inp_c_history[self.end_idxr]
        self.Pep = nn_utils.softmax(
            T.dot(self.We_p, self.last_mem) +
            T.dot(self.Wh_p, self.start_idx_state))
        self.Pspr = nn_utils.softmax(
            T.dot(self.Ws_pr, self.last_mem) +
            T.dot(self.Wh_pr, self.end_idx_state))

        # Final distributions average the forward and reverse passes.
        Ps = (self.Psp + self.Pspr) / 2
        Pe = (self.Pep + self.Pepr) / 2
        self.start_idxr = T.argmax(self.Pspr)
        self.end_idx = T.argmax(self.Pep)

        self.start_idx_f = T.argmax(Ps)
        self.end_idx_f = T.argmax(Pe)

        print("==> collecting all parameters")
        self.params = [
            self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
            self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
            self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
            self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
            self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
            self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, self.W_b,
            self.W_1, self.W_2, self.b_1, self.b_2, self.Ws_p, self.We_p,
            self.Wh_p, self.Ws_pr, self.We_pr, self.Wh_pr
        ]

        print("==> building loss layer and computing updates")
        # Cross-entropy of the averaged start/end distributions against the
        # first gold start/end pointer only.
        loss_start = T.nnet.categorical_crossentropy(
            Ps.dimshuffle("x", 0), T.stack([self.pointers_s_var[0]]))[0]
        loss_end = T.nnet.categorical_crossentropy(
            Pe.dimshuffle("x", 0), T.stack([self.pointers_e_var[0]]))[0]
        self.loss_ce = loss_start + loss_end
        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        updates = lasagne.updates.adadelta(self.loss, self.params)

        if self.mode == 'train':
            print("==> compiling train_fn")
            self.train_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.input_mask_var,
                    self.pointers_s_var, self.pointers_e_var
                ],
                outputs=[self.start_idx_f, self.end_idx_f, self.loss],
                updates=updates,
                allow_input_downcast=True)
        if self.mode != 'minitest':
            print("==> compiling test_fn")
            self.test_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.input_mask_var,
                    self.pointers_s_var, self.pointers_e_var
                ],
                outputs=[
                    self.start_idx_f, self.end_idx_f, self.loss, self.inp_c,
                    self.q_q
                ],
                allow_input_downcast=True)

        if self.mode == 'minitest':
            print("==> compiling minitest_fn")
            self.minitest_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.input_mask_var,
                    self.pointers_s_var, self.pointers_e_var
                ],
                outputs=[self.start_idx_f, self.end_idx_f],
                # The gold pointers do not feed the predicted indices. Keep
                # them in the call signature, but stop Theano from rejecting
                # them as unused inputs (its default is to raise).
                on_unused_input='ignore',
                allow_input_downcast=True)
Exemple #4
0
    def __init__(self, babi_train_raw, babi_test_raw, word2vec,
                 word_vector_size, dim, mode, answer_module, answer_step_nbr,
                 input_mask_mode, memory_hops, l2, normalize_attention,
                 **kwargs):
        '''
        Build the multiple-answer DMN.

        A recurrent answer module emits answer_step_nbr softmax distributions
        over the vocabulary, and the loss sums their cross-entropies.

        :param babi_train_raw: train dataset
        :param babi_test_raw: test dataset
        :param word2vec: a dictionary containing the word embeddings TODO: Check if right
        :param word_vector_size: dimension of the word embeddings (50,100,200,300)
        :param dim: number of hidden units in input module GRU
        :param mode: 'train', 'test' or 'minitest'. It selects the compiled
            Theano functions: train_fn only for 'train', test_fn for anything
            but 'minitest', and minitest_fn only for 'minitest'.
        :param answer_module: answer module type; only 'recurrent' is implemented
        :param answer_step_nbr: number of answer steps (must be >= 1)
        :param input_mask_mode: input_mask_mode: word or sentence
        :param memory_hops: number of episodic memory GRU steps
        :param l2: L2 regularization coefficient (no L2 term when <= 0)
        :param normalize_attention: enable softmax on attention vector
        :param **kwargs: unused; only their names are printed
        :raises Exception: if answer_step_nbr < 1 or answer_module is not
            'recurrent'. Both are checked before any data processing or
            graph building.
        '''

        print("==> not used params in DMN class:", kwargs.keys())
        self.type = "multiple"
        self.vocab = {}
        self.ivocab = {}

        # Fail fast on invalid configuration, before the expensive work below.
        if answer_step_nbr < 1:
            raise Exception('The number of answer step must be greater than 0')
        if answer_module != 'recurrent':
            raise Exception("invalid answer_module")

        # save params
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.answer_step_nbr = answer_step_nbr
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention

        (self.train_input, self.train_q, self.train_answer,
         self.train_input_mask) = self._process_input(babi_train_raw)
        (self.test_input, self.test_q, self.test_answer,
         self.test_input_mask) = self._process_input(babi_test_raw)
        self.vocab_size = len(self.vocab)

        self.input_var = T.matrix('input_var')
        self.q_var = T.matrix('question_var')
        self.answer_var = T.ivector('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')

        print("==> building input module")
        self.W_inp_res_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_upd_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_hid_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        # Run the input GRU over the rows of input_var, starting from a zero
        # state; scan returns one hidden state per row.
        inp_c_history, _ = theano.scan(fn=self.input_gru_step,
                                       sequences=self.input_var,
                                       outputs_info=T.zeros_like(
                                           self.b_inp_hid))

        # Keep only the hidden states at the indices given by input_mask_var.
        self.inp_c = inp_c_history.take(self.input_mask_var, axis=0)

        # Encode the question with the same input GRU weights and keep only
        # the final hidden state.
        self.q_q, _ = theano.scan(fn=self.input_gru_step,
                                  sequences=self.q_var,
                                  outputs_info=T.zeros_like(self.b_inp_hid))
        self.q_q = self.q_q[-1]

        print("==> creating parameters for memory module")
        self.W_mem_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        # Attention mechanism: 2-layer FFNN weights & bias
        # (presumably consumed by self.new_episode, which is not shown here)
        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1,
                                         shape=(self.dim, 7 * self.dim + 2))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim, ))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1, ))

        print(
            "==> building episodic memory module (fixed number of steps: %d)" %
            self.memory_hops)
        # The question encoding initializes the memory; each hop feeds the
        # episode built from the previous memory into the memory GRU.
        memory = [self.q_q.copy()]
        for _hop in range(self.memory_hops):
            prev_mem = memory[-1]
            current_episode = self.new_episode(prev_mem)
            memory.append(
                nn_utils.GRU_update(prev_mem, current_episode,
                                    self.W_mem_res_in, self.W_mem_res_hid,
                                    self.b_mem_res, self.W_mem_upd_in,
                                    self.W_mem_upd_hid, self.b_mem_upd,
                                    self.W_mem_hid_in, self.W_mem_hid_hid,
                                    self.b_mem_hid))

        self.last_mem = memory[-1]

        print("==> building answer module")
        self.W_a = nn_utils.normal_param(std=0.1,
                                         shape=(self.vocab_size, self.dim))

        # Recurrent answer module (answer_module was validated above). The
        # answer GRU input is concatenate([prev_y, q_q, last_mem]), which has
        # vocab_size + dim + dim entries.
        ans_in_dim = self.dim + self.dim + self.vocab_size
        self.W_ans_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, ans_in_dim))
        self.W_ans_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_ans_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, ans_in_dim))
        self.W_ans_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_ans_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, ans_in_dim))
        self.W_ans_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        def answer_step(prev_a, prev_y):
            # One answer step: update the answer GRU state from the previous
            # prediction plus question and memory, then predict a word.
            a = nn_utils.GRU_update(
                prev_a, T.concatenate([prev_y, self.q_q, self.last_mem]),
                self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid)

            y = nn_utils.softmax(T.dot(self.W_a, a))
            return [a, y]

        # TODO: add conditional ending
        # dummy_ only provides a zero vector of length vocab_size as the initial prev_y.
        dummy_ = theano.shared(np.zeros((self.vocab_size, ), dtype=floatX))
        results, _ = theano.scan(
            fn=answer_step,
            outputs_info=[self.last_mem, T.zeros_like(dummy_)],
            n_steps=self.answer_step_nbr)

        # Keep only the per-step predictions y, not the answer GRU states a.
        self.multiple_predictions = results[1]

        print("==> collecting all parameters")
        self.params = [
            self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
            self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
            self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
            self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
            self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
            self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, self.W_b,
            self.W_1, self.W_2, self.b_1, self.b_2, self.W_a,
            # recurrent answer module
            self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
            self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
            self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid
        ]

        print("==> building loss layer and computing updates")

        def temp_loss(curr_pred, curr_ans, loss):
            # Accumulate the cross-entropy of one step's prediction against
            # its gold answer index.
            temp = T.nnet.categorical_crossentropy(
                curr_pred.dimshuffle("x", 0), T.stack([curr_ans]))[0]
            return loss + temp

        outputs, _ = theano.scan(
            fn=temp_loss,
            sequences=[self.multiple_predictions, self.answer_var],
            outputs_info=[np.float64(0.0)],
            n_steps=self.answer_step_nbr)

        # The last accumulator value is the summed loss over all answer steps.
        self.loss_ce = outputs[-1]
        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        updates = lasagne.updates.adadelta(self.loss, self.params)

        # answer_var is already an int32 ivector, so it is passed directly
        # (T.cast to 'int32' returned the same variable).
        if self.mode == 'train':
            print("==> compiling train_fn")
            self.train_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.answer_var,
                    self.input_mask_var
                ],
                outputs=[self.multiple_predictions, self.loss],
                updates=updates,
                allow_input_downcast=True)
        if self.mode != 'minitest':
            print("==> compiling test_fn")
            self.test_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.answer_var,
                    self.input_mask_var
                ],
                outputs=[
                    self.multiple_predictions, self.loss, self.inp_c,
                    self.q_q, self.last_mem
                ],
                allow_input_downcast=True)

        if self.mode == 'minitest':
            print("==> compiling minitest_fn")
            self.minitest_fn = theano.function(
                inputs=[self.input_var, self.q_var, self.input_mask_var],
                outputs=[self.multiple_predictions],
                allow_input_downcast=True)