示例#1
0
    def __init__(self, train_raw, test_raw, word2vec, word_vector_size, dim,
                mode, answer_module, input_mask_mode, memory_hops, batch_size, l2,
                normalize_attention, batch_norm, dropout, **kwargs):
        """Build the batched DMN graph and compile its Theano functions.

        Args:
            train_raw: mapping with 'inputs', 'questions', 'answers' and
                'input_masks' entries for the training split.
            test_raw: same layout as ``train_raw`` for the test split.
            word2vec: stored on the instance; not read in this constructor.
            word_vector_size: stored, then overwritten by the second
                dimension of the first training input.
            dim: hidden size used for every GRU / attention parameter.
            mode: 'train' enables dropout and compiles ``train_fn``.
            answer_module: 'feedforward' or 'recurrent'.
            input_mask_mode: stored on the instance; not read here.
            memory_hops: number of episodic memory passes.
            batch_size: fixed minibatch size baked into the graph.
            l2: L2 coefficient; the L2 term is 0 unless this is > 0.
            normalize_attention: stored on the instance; not read here.
            batch_norm: when truthy, a BatchNormLayer is applied to the
                final memory.
            dropout: dropout probability, applied when > 0 in 'train' mode.
            **kwargs: ignored; their keys are printed.

        Raises:
            Exception: if ``answer_module`` is not a supported value.
        """

        print("==> not used params in DMN class:", kwargs.keys())

        self.vocab = {}
        self.ivocab = {}

        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.batch_size = batch_size
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.batch_norm = batch_norm
        self.dropout = dropout

        # The raw splits already carry their separate parts, including the
        # input masks, so they are only unpacked here.
        self.train_input = train_raw['inputs']
        self.train_q = train_raw['questions']
        self.train_answer = train_raw['answers']
        self.train_input_mask = train_raw['input_masks']
        self.test_input = test_raw['inputs']
        self.test_q = test_raw['questions']
        self.test_answer = test_raw['answers']
        self.test_input_mask = test_raw['input_masks']

        # The fact count of an example is taken to be the length of its mask.
        self.train_fact_count = [len(mask) for mask in self.train_input_mask]
        self.test_fact_count = [len(mask) for mask in self.test_input_mask]

        # Vocab size might only be used for the answer module and we only
        # want to give that two choices.
        self.vocab_size = 2
        print(np.array(self.train_answer).shape,
              np.array(self.train_input[0]).shape, np.array(self.train_fact_count).shape,
              np.array(self.train_input_mask).shape)

        # Get the size of the word vector from the data itself as this can be
        # variable now (this overrides the constructor argument).
        self.word_vector_size = np.array(self.train_input[0]).shape[1]
        print(self.word_vector_size)

        self.input_var = T.tensor3('input_var') # (batch_size, seq_len, glove_dim)
        self.q_var = T.tensor3('question_var') # as self.input_var
        self.answer_var = T.ivector('answer_var') # answer of example in minibatch
        self.fact_count_var = T.ivector('fact_count_var') # number of facts in the example of minibatch
        self.input_mask_var = T.imatrix('input_mask_var') # (batch_size, indices)

        print("==> building input module")
        self.W_inp_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_inp_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_inp_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        # Scan iterates over the leading axis, so move the sequence axis first
        # and the batch axis last.
        input_var_shuffled = self.input_var.dimshuffle(1, 2, 0)
        inp_dummy = theano.shared(np.zeros((self.dim, self.batch_size), dtype=floatX))
        inp_c_history, _ = theano.scan(fn=self.input_gru_step,
                            sequences=input_var_shuffled,
                            outputs_info=T.zeros_like(inp_dummy))

        # Bring the batch axis back to the front so examples can be indexed.
        inp_c_history_shuffled = inp_c_history.dimshuffle(2, 0, 1)

        inp_c_list = []
        inp_c_mask_list = []
        for batch_index in range(self.batch_size):
            # Pick the hidden states at this example's masked positions, then
            # zero-pad up to the mask width so all examples stack evenly.
            taken = inp_c_history_shuffled[batch_index].take(self.input_mask_var[batch_index, :self.fact_count_var[batch_index]], axis=0)
            n_pad = self.input_mask_var.shape[1] - taken.shape[0]
            inp_c_list.append(T.concatenate([taken, T.zeros((n_pad, self.dim), floatX)]))
            # 1 for a real fact, 0 for a padding row.
            inp_c_mask_list.append(T.concatenate([T.ones((taken.shape[0],), np.int32), T.zeros((n_pad,), np.int32)]))

        self.inp_c = T.stack(inp_c_list).dimshuffle(1, 2, 0)
        # NOTE(review): this mask is built but not consumed in this constructor.
        inp_c_mask = T.stack(inp_c_mask_list).dimshuffle(1, 0)

        # The question is encoded with the same GRU step as the input; only
        # the final state is kept.
        q_var_shuffled = self.q_var.dimshuffle(1, 2, 0)
        q_dummy = theano.shared(np.zeros((self.dim, self.batch_size), dtype=floatX))
        q_q_history, _ = theano.scan(fn=self.input_gru_step,
                            sequences=q_var_shuffled,
                            outputs_info=T.zeros_like(q_dummy))
        self.q_q = q_q_history[-1]


        print("==> creating parameters for memory module")
        self.W_mem_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 0))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1,))


        print("==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops)
        # The memory starts as a copy of the question encoding; every hop
        # computes a new episode and folds it into the previous memory.
        memory = [self.q_q.copy()]
        for _ in range(self.memory_hops):
            prev_memory = memory[-1]
            current_episode = self.new_episode(prev_memory)
            memory.append(self.GRU_update(prev_memory, current_episode,
                                          self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
                                          self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                                          self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid))

        # The Lasagne layers below are declared with shape (batch_size, dim),
        # so the memory is transposed going in and transposed back coming out.
        last_mem_raw = memory[-1].dimshuffle((1, 0))

        net = layers.InputLayer(shape=(self.batch_size, self.dim), input_var=last_mem_raw)
        if self.batch_norm:
            net = layers.BatchNormLayer(incoming=net)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net).dimshuffle((1, 0))


        print("==> building answer module")
        self.W_a = nn_utils.normal_param(std=0.1, shape=(self.vocab_size, self.dim))

        if self.answer_module == 'feedforward':
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))

        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))

            self.W_ans_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

            self.W_ans_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

            def answer_step(prev_a, prev_y):
                # One answer-GRU step fed with [previous output, question].
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                  self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                                  self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                                  self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid)

                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]

            # TODO: add conditional ending
            dummy = theano.shared(np.zeros((self.vocab_size, self.batch_size), dtype=floatX))
            # Pass the closure defined above. The code used to pass
            # `self.answer_step`, an attribute this constructor never sets.
            results, _ = theano.scan(fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(dummy)], #(last_mem, y)
                n_steps=1)
            self.prediction = results[1][-1]

        else:
            raise Exception("invalid answer_module")

        # Swap the two axes before the loss (presumably to batch-major so the
        # rows line up with answer_var -- TODO confirm).
        self.prediction = self.prediction.dimshuffle(1, 0)

        # W_b is created above but is not part of the trainable list.
        self.params = [self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
                  self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
                  self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
                  self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
                  self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                  self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid,
                  self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]

        if self.answer_module == 'recurrent':
            self.params = self.params + [self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                              self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                              self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid]


        print("==> building loss layer and computing updates")
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction, self.answer_var).mean()

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        updates = lasagne.updates.adadelta(self.loss, self.params)

        if self.mode == 'train':
            print("==> compiling train_fn")
            self.train_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var,
                                                    self.fact_count_var, self.input_mask_var],
                                            outputs=[self.prediction, self.loss],
                                            updates=updates)

        print("==> compiling test_fn")
        self.test_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var,
                                               self.fact_count_var, self.input_mask_var],
                                       outputs=[self.prediction, self.loss])
示例#2
0
    def __init__(self, train_raw, test_raw, word2vec, word_vector_size, dim,
                 mode, answer_module, input_mask_mode, memory_hops, l2,
                 normalize_attention, num_cnn_layers, vocab_len,
                 maximum_doc_len, char_vocab, **kwargs):
        """Assemble the CNN-fed DMN graph and compile its Theano functions.

        Args:
            train_raw: training split, handed to ``self._process_input``.
            test_raw: test split, handed to ``self._process_input``.
            word2vec: stored on the instance; not read in this constructor.
            word_vector_size: stored on the instance; not read here.
            dim: hidden size used for the GRU / attention parameters.
            mode: 'train' additionally compiles ``train_fn`` and
                ``get_gradient_fn``.
            answer_module: 'feedforward' or 'recurrent'.
            input_mask_mode: stored on the instance; not read here.
            memory_hops: number of episodic memory passes.
            l2: L2 coefficient; the L2 term is 0 unless this is > 0.
            normalize_attention: stored on the instance; not read here.
            num_cnn_layers: stored as ``final_num_layers``.
            vocab_len: stored as ``vocab_length``.
            maximum_doc_len: stored as ``max_doc_length``.
            char_vocab: stored on the instance.
            **kwargs: ignored; their keys are printed.

        Raises:
            Exception: if ``answer_module`` is not a supported value.
        """

        def weight(*shape):
            # Shorthand for nn_utils.normal_param with std=0.1.
            return nn_utils.normal_param(std=0.1, shape=shape)

        def bias(*shape):
            # Shorthand for nn_utils.constant_param with value=0.0.
            return nn_utils.constant_param(value=0.0, shape=shape)

        print("==> not used params in DMN class:", kwargs.keys())
        self.vocab, self.ivocab = {}, {}

        # Plain configuration, kept on the instance.
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.final_num_layers = num_cnn_layers
        self.vocab_length = vocab_len
        self.max_doc_length = maximum_doc_len
        self.char_vocab = char_vocab

        self.cnn_layer_length = 100

        # Split each raw set into inputs, questions, answers and input masks.
        (self.train_input_raw, self.train_q_raw,
         self.train_answer, self.train_input_mask) = self._process_input(train_raw)
        (self.test_input_raw, self.test_q_raw,
         self.test_answer, self.test_input_mask) = self._process_input(test_raw)
        self.vocab_size = len(self.vocab)

        print(type(self.train_input_raw), len(self.train_input_raw))
        # Inputs first (train, test), then questions (train, test).
        self.train_input = self.build_cnn(self.train_input_raw)
        self.test_input = self.build_cnn(self.test_input_raw)
        self.train_q = self.build_cnn(self.train_q_raw)
        self.test_q = self.build_cnn(self.test_q_raw)
        print(self.train_answer)

        # Symbolic inputs of the compiled functions.
        self.input_var = T.matrix('input_var')
        self.q_var = T.matrix('question_var')
        self.answer_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')

        print("==> building input module")
        self.W_inp_res_in = weight(self.dim, self.cnn_layer_length)
        self.W_inp_res_hid = weight(self.dim, self.dim)
        self.b_inp_res = bias(self.dim)

        self.W_inp_upd_in = weight(self.dim, self.cnn_layer_length)
        self.W_inp_upd_hid = weight(self.dim, self.dim)
        self.b_inp_upd = bias(self.dim)

        self.W_inp_hid_in = weight(self.dim, self.cnn_layer_length)
        self.W_inp_hid_hid = weight(self.dim, self.dim)
        self.b_inp_hid = bias(self.dim)

        input_gru_params = [
            self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
            self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
            self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
        ]

        # Run the input GRU over the input rows and keep the states selected
        # by the input mask.
        fact_states, _ = theano.scan(
            fn=self.input_gru_step,
            sequences=self.input_var,
            outputs_info=T.zeros_like(self.b_inp_hid))
        self.inp_c = fact_states.take(self.input_mask_var, axis=0)

        # The question goes through the same GRU step; only the last state
        # is kept.
        question_states, _ = theano.scan(
            fn=self.input_gru_step,
            sequences=self.q_var,
            outputs_info=T.zeros_like(self.b_inp_hid))
        self.q_q = question_states[-1]

        print("==> creating parameters for memory module")
        self.W_mem_res_in = weight(self.dim, self.dim)
        self.W_mem_res_hid = weight(self.dim, self.dim)
        self.b_mem_res = bias(self.dim)

        self.W_mem_upd_in = weight(self.dim, self.dim)
        self.W_mem_upd_hid = weight(self.dim, self.dim)
        self.b_mem_upd = bias(self.dim)

        self.W_mem_hid_in = weight(self.dim, self.dim)
        self.W_mem_hid_hid = weight(self.dim, self.dim)
        self.b_mem_hid = bias(self.dim)

        self.W_b = weight(self.dim, self.dim)
        self.W_1 = weight(self.dim, 7 * self.dim + 2)
        self.W_2 = weight(1, self.dim)
        self.b_1 = bias(self.dim)
        self.b_2 = bias(1)

        # Same order as the positional parameters GRU_update is called with.
        memory_gru_params = (
            self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
            self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
            self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid,
        )

        print("==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops)
        # The memory starts as a copy of the question encoding; each hop
        # folds a new episode into the most recent memory.
        memory = [self.q_q.copy()]
        for _ in range(self.memory_hops):
            previous = memory[-1]
            episode = self.new_episode(previous)
            memory.append(
                self.GRU_update(previous, episode, *memory_gru_params))

        last_mem = memory[-1]

        print("==> building answer module")
        self.W_a = weight(self.vocab_size, self.dim)

        if self.answer_module == 'feedforward':
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))

        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = weight(self.dim, self.dim + self.vocab_size)
            self.W_ans_res_hid = weight(self.dim, self.dim)
            self.b_ans_res = bias(self.dim)

            self.W_ans_upd_in = weight(self.dim, self.dim + self.vocab_size)
            self.W_ans_upd_hid = weight(self.dim, self.dim)
            self.b_ans_upd = bias(self.dim)

            self.W_ans_hid_in = weight(self.dim, self.dim + self.vocab_size)
            self.W_ans_hid_hid = weight(self.dim, self.dim)
            self.b_ans_hid = bias(self.dim)

            answer_gru_params = (
                self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid,
            )

            def answer_step(prev_state, prev_output):
                # One answer-GRU step fed with [previous output, question].
                state = self.GRU_update(
                    prev_state, T.concatenate([prev_output, self.q_q]),
                    *answer_gru_params)
                return [state, nn_utils.softmax(T.dot(self.W_a, state))]

            # TODO: add conditional ending
            y_init = theano.shared(np.zeros((self.vocab_size, ), dtype=floatX))
            scan_out, _ = theano.scan(
                fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(y_init)],
                n_steps=1)
            self.prediction = scan_out[1][-1]

        else:
            raise Exception("invalid answer_module")

        print("==> collecting all parameters")
        # TODO add in the cnn params
        self.params = input_gru_params + list(memory_gru_params) + [
            self.W_b, self.W_1, self.W_2, self.b_1, self.b_2,
            self.W_a, self.net_w,
        ]

        if self.answer_module == 'recurrent':
            self.params = self.params + list(answer_gru_params)

        print("==> building loss layer and computing updates")
        # Cross-entropy is computed on a batch of one: the prediction gets a
        # leading broadcast axis and the scalar answer is stacked to length 1.
        cross_entropy = T.nnet.categorical_crossentropy(
            self.prediction.dimshuffle('x', 0), T.stack([self.answer_var]))
        self.loss_ce = cross_entropy[0]
        self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params) if self.l2 > 0 else 0

        self.loss = self.loss_ce + self.loss_l2

        updates = lasagne.updates.adadelta(self.loss, self.params)

        # Every compiled function takes the same four symbolic inputs.
        graph_inputs = (self.input_var, self.q_var, self.answer_var,
                        self.input_mask_var)

        if self.mode == 'train':
            print("==> compiling train_fn")
            self.train_fn = theano.function(
                inputs=list(graph_inputs),
                outputs=[self.prediction, self.loss],
                updates=updates)

        print("==> compiling test_fn")
        self.test_fn = theano.function(
            inputs=list(graph_inputs),
            outputs=[self.prediction, self.loss, self.inp_c, self.q_q,
                     last_mem])

        if self.mode == 'train':
            print("==> computing gradients (for debugging)")
            gradient = T.grad(self.loss, self.params)
            self.get_gradient_fn = theano.function(
                inputs=list(graph_inputs), outputs=gradient)
示例#3
0
    def __init__(
        self,
        train_raw,
        test_raw,
        word2vec,
        word_vector_size,
        dim,
        mode,
        answer_module,
        input_mask_mode,
        memory_hops,
        l2,
        normalize_attention,
        num_cnn_layers,
        vocab_len,
        maximum_doc_len,
        char_vocab,
        **kwargs
    ):
        """Assemble the CNN-fed DMN graph and compile its Theano functions.

        Args:
            train_raw: training split, handed to ``self._process_input``.
            test_raw: test split, handed to ``self._process_input``.
            word2vec: stored on the instance; not read in this constructor.
            word_vector_size: stored on the instance; not read here.
            dim: hidden size used for the GRU / attention parameters.
            mode: "train" additionally compiles ``train_fn`` and
                ``get_gradient_fn``.
            answer_module: "feedforward" or "recurrent".
            input_mask_mode: stored on the instance; not read here.
            memory_hops: number of episodic memory passes.
            l2: L2 coefficient; the L2 term is 0 unless this is > 0.
            normalize_attention: stored on the instance; not read here.
            num_cnn_layers: stored as ``final_num_layers``.
            vocab_len: stored as ``vocab_length``.
            maximum_doc_len: stored as ``max_doc_length``.
            char_vocab: stored on the instance.
            **kwargs: ignored; their keys are printed.

        Raises:
            Exception: if ``answer_module`` is not a supported value.
        """

        def weight(*shape):
            # Shorthand for nn_utils.normal_param with std=0.1.
            return nn_utils.normal_param(std=0.1, shape=shape)

        def bias(*shape):
            # Shorthand for nn_utils.constant_param with value=0.0.
            return nn_utils.constant_param(value=0.0, shape=shape)

        print("==> not used params in DMN class:", kwargs.keys())
        self.vocab, self.ivocab = {}, {}

        # Plain configuration, kept on the instance.
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.final_num_layers = num_cnn_layers
        self.vocab_length = vocab_len
        self.max_doc_length = maximum_doc_len
        self.char_vocab = char_vocab

        self.cnn_layer_length = 100

        # Split each raw set into inputs, questions, answers and input masks.
        (
            self.train_input_raw,
            self.train_q_raw,
            self.train_answer,
            self.train_input_mask,
        ) = self._process_input(train_raw)
        (
            self.test_input_raw,
            self.test_q_raw,
            self.test_answer,
            self.test_input_mask,
        ) = self._process_input(test_raw)
        self.vocab_size = len(self.vocab)

        print(type(self.train_input_raw), len(self.train_input_raw))
        # Inputs first (train, test), then questions (train, test).
        self.train_input = self.build_cnn(self.train_input_raw)
        self.test_input = self.build_cnn(self.test_input_raw)
        self.train_q = self.build_cnn(self.train_q_raw)
        self.test_q = self.build_cnn(self.test_q_raw)
        print(self.train_answer)

        # Symbolic inputs of the compiled functions.
        self.input_var = T.matrix("input_var")
        self.q_var = T.matrix("question_var")
        self.answer_var = T.iscalar("answer_var")
        self.input_mask_var = T.ivector("input_mask_var")

        print("==> building input module")
        self.W_inp_res_in = weight(self.dim, self.cnn_layer_length)
        self.W_inp_res_hid = weight(self.dim, self.dim)
        self.b_inp_res = bias(self.dim)

        self.W_inp_upd_in = weight(self.dim, self.cnn_layer_length)
        self.W_inp_upd_hid = weight(self.dim, self.dim)
        self.b_inp_upd = bias(self.dim)

        self.W_inp_hid_in = weight(self.dim, self.cnn_layer_length)
        self.W_inp_hid_hid = weight(self.dim, self.dim)
        self.b_inp_hid = bias(self.dim)

        input_gru_params = [
            self.W_inp_res_in,
            self.W_inp_res_hid,
            self.b_inp_res,
            self.W_inp_upd_in,
            self.W_inp_upd_hid,
            self.b_inp_upd,
            self.W_inp_hid_in,
            self.W_inp_hid_hid,
            self.b_inp_hid,
        ]

        # Run the input GRU over the input rows and keep the states selected
        # by the input mask.
        fact_states, _ = theano.scan(
            fn=self.input_gru_step,
            sequences=self.input_var,
            outputs_info=T.zeros_like(self.b_inp_hid),
        )
        self.inp_c = fact_states.take(self.input_mask_var, axis=0)

        # The question goes through the same GRU step; only the last state
        # is kept.
        question_states, _ = theano.scan(
            fn=self.input_gru_step,
            sequences=self.q_var,
            outputs_info=T.zeros_like(self.b_inp_hid),
        )
        self.q_q = question_states[-1]

        print("==> creating parameters for memory module")
        self.W_mem_res_in = weight(self.dim, self.dim)
        self.W_mem_res_hid = weight(self.dim, self.dim)
        self.b_mem_res = bias(self.dim)

        self.W_mem_upd_in = weight(self.dim, self.dim)
        self.W_mem_upd_hid = weight(self.dim, self.dim)
        self.b_mem_upd = bias(self.dim)

        self.W_mem_hid_in = weight(self.dim, self.dim)
        self.W_mem_hid_hid = weight(self.dim, self.dim)
        self.b_mem_hid = bias(self.dim)

        self.W_b = weight(self.dim, self.dim)
        self.W_1 = weight(self.dim, 7 * self.dim + 2)
        self.W_2 = weight(1, self.dim)
        self.b_1 = bias(self.dim)
        self.b_2 = bias(1)

        # Same order as the positional parameters GRU_update is called with.
        memory_gru_params = (
            self.W_mem_res_in,
            self.W_mem_res_hid,
            self.b_mem_res,
            self.W_mem_upd_in,
            self.W_mem_upd_hid,
            self.b_mem_upd,
            self.W_mem_hid_in,
            self.W_mem_hid_hid,
            self.b_mem_hid,
        )

        print("==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops)
        # The memory starts as a copy of the question encoding; each hop
        # folds a new episode into the most recent memory.
        memory = [self.q_q.copy()]
        for _ in range(self.memory_hops):
            previous = memory[-1]
            episode = self.new_episode(previous)
            memory.append(self.GRU_update(previous, episode, *memory_gru_params))

        last_mem = memory[-1]

        print("==> building answer module")
        self.W_a = weight(self.vocab_size, self.dim)

        if self.answer_module == "feedforward":
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))

        elif self.answer_module == "recurrent":
            self.W_ans_res_in = weight(self.dim, self.dim + self.vocab_size)
            self.W_ans_res_hid = weight(self.dim, self.dim)
            self.b_ans_res = bias(self.dim)

            self.W_ans_upd_in = weight(self.dim, self.dim + self.vocab_size)
            self.W_ans_upd_hid = weight(self.dim, self.dim)
            self.b_ans_upd = bias(self.dim)

            self.W_ans_hid_in = weight(self.dim, self.dim + self.vocab_size)
            self.W_ans_hid_hid = weight(self.dim, self.dim)
            self.b_ans_hid = bias(self.dim)

            answer_gru_params = (
                self.W_ans_res_in,
                self.W_ans_res_hid,
                self.b_ans_res,
                self.W_ans_upd_in,
                self.W_ans_upd_hid,
                self.b_ans_upd,
                self.W_ans_hid_in,
                self.W_ans_hid_hid,
                self.b_ans_hid,
            )

            def answer_step(prev_state, prev_output):
                # One answer-GRU step fed with [previous output, question].
                state = self.GRU_update(
                    prev_state,
                    T.concatenate([prev_output, self.q_q]),
                    *answer_gru_params
                )
                return [state, nn_utils.softmax(T.dot(self.W_a, state))]

            # TODO: add conditional ending
            y_init = theano.shared(np.zeros((self.vocab_size,), dtype=floatX))
            scan_out, _ = theano.scan(
                fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(y_init)],
                n_steps=1,
            )
            self.prediction = scan_out[1][-1]

        else:
            raise Exception("invalid answer_module")

        print("==> collecting all parameters")
        # TODO add in the cnn params
        self.params = (
            input_gru_params
            + list(memory_gru_params)
            + [self.W_b, self.W_1, self.W_2, self.b_1, self.b_2, self.W_a, self.net_w]
        )

        if self.answer_module == "recurrent":
            self.params = self.params + list(answer_gru_params)

        print("==> building loss layer and computing updates")
        # Cross-entropy is computed on a batch of one: the prediction gets a
        # leading broadcast axis and the scalar answer is stacked to length 1.
        cross_entropy = T.nnet.categorical_crossentropy(
            self.prediction.dimshuffle("x", 0), T.stack([self.answer_var])
        )
        self.loss_ce = cross_entropy[0]
        self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params) if self.l2 > 0 else 0

        self.loss = self.loss_ce + self.loss_l2

        updates = lasagne.updates.adadelta(self.loss, self.params)

        # Every compiled function takes the same four symbolic inputs.
        graph_inputs = (self.input_var, self.q_var, self.answer_var, self.input_mask_var)

        if self.mode == "train":
            print("==> compiling train_fn")
            self.train_fn = theano.function(
                inputs=list(graph_inputs),
                outputs=[self.prediction, self.loss],
                updates=updates,
            )

        print("==> compiling test_fn")
        self.test_fn = theano.function(
            inputs=list(graph_inputs),
            outputs=[self.prediction, self.loss, self.inp_c, self.q_q, last_mem],
        )

        if self.mode == "train":
            print("==> computing gradients (for debugging)")
            gradient = T.grad(self.loss, self.params)
            self.get_gradient_fn = theano.function(inputs=list(graph_inputs), outputs=gradient)
示例#4
0
    def __init__(self, train_raw, test_raw, word2vec, word_vector_size, dim,
                 mode, answer_module, input_mask_mode, memory_hops, l2,
                 normalize_attention, **kwargs):
        """Build the Theano graph of the network and compile its functions.

        Args:
            train_raw: mapping with 'inputs', 'questions', 'answers' and
                'input_masks' entries holding the training data.
            test_raw: mapping with the same four entries for the test data.
            word2vec: stored on the instance.
            word_vector_size: stored, then overwritten by the width
                (``shape[1]``) of the first training input.
            dim: hidden size; sets the shapes of the input, memory and
                answer parameters created here.
            mode: when equal to 'train', ``train_fn`` and
                ``get_gradient_fn`` are compiled in addition to ``test_fn``.
            answer_module: 'feedforward' or 'recurrent'.
            input_mask_mode: stored on the instance.
            memory_hops: number of episodic memory updates to unroll.
            l2: coefficient of the L2 penalty; no penalty unless it is > 0.
            normalize_attention: stored on the instance.
            **kwargs: ignored apart from printing their keys.

        Raises:
            Exception: if ``answer_module`` is neither 'feedforward' nor
                'recurrent'.
        """
        print("==> not used params in DMN class:", kwargs.keys())
        self.vocab = {}
        self.ivocab = {}

        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention

        self.to_return = "word2vec"  # TODO pass these in once updates are wrapped in
        self.encoder_decoder = None
        self.vocab_dict = {}

        # The raw data already comes split into its parts (including the
        # input masks), so it is stored without further processing.
        self.train_input = train_raw['inputs']
        self.train_q = train_raw['questions']
        self.train_answer = train_raw['answers']
        self.train_input_mask = train_raw['input_masks']
        self.test_input = test_raw['inputs']
        self.test_q = test_raw['questions']
        self.test_answer = test_raw['answers']
        self.test_input_mask = test_raw['input_masks']

        # Vocab size might only be used for the answer module, and we only
        # want to give that two choices.
        self.vocab_size = 2

        # Get the size of the word vector from the data itself as this can
        # be variable now (overrides the constructor argument).
        self.word_vector_size = np.array(self.train_input[0]).shape[1]

        self.input_var = T.matrix('input_var')
        self.q_var = T.matrix('question_var')
        self.answer_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')
        print(len(self.train_answer), sum(self.train_answer),
              len(self.test_answer), sum(self.test_answer))

        print("==> building input module")
        self.W_inp_res_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_upd_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_hid_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        # Run the input GRU step over the rows of the input matrix, starting
        # from an all-zero hidden state.
        inp_c_history, _ = theano.scan(fn=self.input_gru_step,
                                       sequences=self.input_var,
                                       outputs_info=T.zeros_like(
                                           self.b_inp_hid))

        # Keep only the hidden states at the positions listed in the mask.
        self.inp_c = inp_c_history.take(self.input_mask_var, axis=0)

        # The question goes through the same GRU step; only its final hidden
        # state is kept.
        self.q_q, _ = theano.scan(fn=self.input_gru_step,
                                  sequences=self.q_var,
                                  outputs_info=T.zeros_like(self.b_inp_hid))

        self.q_q = self.q_q[-1]

        print("==> creating parameters for memory module")
        self.W_mem_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1,
                                         shape=(self.dim, 7 * self.dim + 2))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim, ))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1, ))

        print(
            "==> building episodic memory module (fixed number of steps: %d)" %
            self.memory_hops)
        # The memory starts as a copy of the question encoding; every hop
        # computes a new episode from the latest memory and feeds both into
        # the memory GRU.
        memory = [self.q_q.copy()]
        for _ in range(self.memory_hops):
            previous_memory = memory[-1]
            current_episode = self.new_episode(previous_memory)
            memory.append(
                self.GRU_update(previous_memory, current_episode,
                                self.W_mem_res_in, self.W_mem_res_hid,
                                self.b_mem_res, self.W_mem_upd_in,
                                self.W_mem_upd_hid, self.b_mem_upd,
                                self.W_mem_hid_in, self.W_mem_hid_hid,
                                self.b_mem_hid))

        last_mem = memory[-1]

        print("==> building answer module")
        self.W_a = nn_utils.normal_param(std=0.1,
                                         shape=(self.vocab_size, self.dim))

        if self.answer_module == 'feedforward':
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))

        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_res_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_res = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            self.W_ans_upd_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_upd_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_upd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            self.W_ans_hid_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_hid_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_hid = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            def answer_step(prev_a, prev_y):
                # The GRU input is the previous prediction concatenated with
                # the question encoding.
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                    self.W_ans_res_in, self.W_ans_res_hid,
                                    self.b_ans_res, self.W_ans_upd_in,
                                    self.W_ans_upd_hid, self.b_ans_upd,
                                    self.W_ans_hid_in, self.W_ans_hid_hid,
                                    self.b_ans_hid)

                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]

            # TODO: add conditional ending
            dummy = theano.shared(np.zeros((self.vocab_size, ), dtype=floatX))
            # The scan's second return value is not used; it is discarded
            # explicitly so it cannot be confused with the optimizer updates
            # built further down.
            results, _ = theano.scan(
                fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(dummy)],
                n_steps=1)
            self.prediction = results[1][-1]

        else:
            raise Exception("invalid answer_module")

        print("==> collecting all parameters")
        self.params = [
            self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
            self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
            self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
            self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
            self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
            self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, self.W_b,
            self.W_1, self.W_2, self.b_1, self.b_2, self.W_a
        ]

        if self.answer_module == 'recurrent':
            self.params = self.params + [
                self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid
            ]

        print("==> building loss layer and computing updates")
        # Cross-entropy between the predicted distribution (as a single-row
        # batch) and the target answer index.
        self.loss_ce = T.nnet.categorical_crossentropy(
            self.prediction.dimshuffle('x', 0), T.stack([self.answer_var]))[0]
        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        # Every compiled function takes the same four symbolic inputs.
        fn_inputs = [
            self.input_var, self.q_var, self.answer_var, self.input_mask_var
        ]

        if self.mode == 'train':
            # The optimizer updates are only consumed by train_fn, so they
            # are built in train mode only.
            updates = lasagne.updates.adadelta(self.loss, self.params)

            print("==> compiling train_fn")
            self.train_fn = theano.function(
                inputs=fn_inputs,
                outputs=[self.prediction, self.loss],
                updates=updates)

        print("==> compiling test_fn")
        self.test_fn = theano.function(
            inputs=fn_inputs,
            outputs=[
                self.prediction, self.loss, self.inp_c, self.q_q, last_mem
            ])

        if self.mode == 'train':
            print("==> computing gradients (for debugging)")
            gradient = T.grad(self.loss, self.params)
            self.get_gradient_fn = theano.function(inputs=fn_inputs,
                                                   outputs=gradient)
示例#5
0
文件: dmn_basic.py 项目: Lab41/pythia
    def __init__(self, train_raw, test_raw, word2vec, word_vector_size,
                 dim, mode, answer_module, input_mask_mode, memory_hops, l2,
                 normalize_attention, **kwargs):
        """Build the Theano graph of the network and compile its functions.

        Args:
            train_raw: mapping with 'inputs', 'questions', 'answers' and
                'input_masks' entries holding the training data.
            test_raw: mapping with the same four entries for the test data.
            word2vec: stored on the instance.
            word_vector_size: stored, then overwritten by the width
                (``shape[1]``) of the first training input.
            dim: hidden size; sets the shapes of the input, memory and
                answer parameters created here.
            mode: when equal to 'train', ``train_fn`` and
                ``get_gradient_fn`` are compiled in addition to ``test_fn``.
            answer_module: 'feedforward' or 'recurrent'.
            input_mask_mode: stored on the instance.
            memory_hops: number of episodic memory updates to unroll.
            l2: coefficient of the L2 penalty; no penalty unless it is > 0.
            normalize_attention: stored on the instance.
            **kwargs: ignored apart from printing their keys.

        Raises:
            Exception: if ``answer_module`` is neither 'feedforward' nor
                'recurrent'.
        """
        print("==> not used params in DMN class:", kwargs.keys())
        self.vocab = {}
        self.ivocab = {}

        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention

        self.to_return = "word2vec"  # TODO pass these in once updates are wrapped in
        self.encoder_decoder = None
        self.vocab_dict = {}

        # The raw data already comes split into its parts (including the
        # input masks), so it is stored without further processing.
        self.train_input = train_raw['inputs']
        self.train_q = train_raw['questions']
        self.train_answer = train_raw['answers']
        self.train_input_mask = train_raw['input_masks']
        self.test_input = test_raw['inputs']
        self.test_q = test_raw['questions']
        self.test_answer = test_raw['answers']
        self.test_input_mask = test_raw['input_masks']

        # Vocab size might only be used for the answer module, and we only
        # want to give that two choices.
        self.vocab_size = 2

        # Get the size of the word vector from the data itself as this can
        # be variable now (overrides the constructor argument).
        self.word_vector_size = np.array(self.train_input[0]).shape[1]

        self.input_var = T.matrix('input_var')
        self.q_var = T.matrix('question_var')
        self.answer_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')
        print(len(self.train_answer), sum(self.train_answer), len(self.test_answer), sum(self.test_answer))

        print("==> building input module")
        self.W_inp_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_inp_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_inp_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        # Run the input GRU step over the rows of the input matrix, starting
        # from an all-zero hidden state.
        inp_c_history, _ = theano.scan(fn=self.input_gru_step,
                                       sequences=self.input_var,
                                       outputs_info=T.zeros_like(self.b_inp_hid))

        # Keep only the hidden states at the positions listed in the mask.
        self.inp_c = inp_c_history.take(self.input_mask_var, axis=0)

        # The question goes through the same GRU step; only its final hidden
        # state is kept.
        self.q_q, _ = theano.scan(fn=self.input_gru_step,
                                  sequences=self.q_var,
                                  outputs_info=T.zeros_like(self.b_inp_hid))

        self.q_q = self.q_q[-1]

        print("==> creating parameters for memory module")
        self.W_mem_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 2))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1,))

        print("==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops)
        # The memory starts as a copy of the question encoding; every hop
        # computes a new episode from the latest memory and feeds both into
        # the memory GRU.
        memory = [self.q_q.copy()]
        for _ in range(self.memory_hops):
            previous_memory = memory[-1]
            current_episode = self.new_episode(previous_memory)
            memory.append(self.GRU_update(previous_memory, current_episode,
                                          self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
                                          self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                                          self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid))

        last_mem = memory[-1]

        print("==> building answer module")
        self.W_a = nn_utils.normal_param(std=0.1, shape=(self.vocab_size, self.dim))

        if self.answer_module == 'feedforward':
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))

        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))

            self.W_ans_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

            self.W_ans_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

            def answer_step(prev_a, prev_y):
                # The GRU input is the previous prediction concatenated with
                # the question encoding.
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                    self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                                    self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                                    self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid)

                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]

            # TODO: add conditional ending
            dummy = theano.shared(np.zeros((self.vocab_size, ), dtype=floatX))
            # The scan's second return value is not used; it is discarded
            # explicitly so it cannot be confused with the optimizer updates
            # built further down.
            results, _ = theano.scan(fn=answer_step,
                                     outputs_info=[last_mem, T.zeros_like(dummy)],
                                     n_steps=1)
            self.prediction = results[1][-1]

        else:
            raise Exception("invalid answer_module")

        print("==> collecting all parameters")
        self.params = [self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
                       self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
                       self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
                       self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
                       self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                       self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid,
                       self.W_b, self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]

        if self.answer_module == 'recurrent':
            self.params = self.params + [self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                                         self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                                         self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid]

        print("==> building loss layer and computing updates")
        # Cross-entropy between the predicted distribution (as a single-row
        # batch) and the target answer index.
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction.dimshuffle('x', 0), T.stack([self.answer_var]))[0]
        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        # Every compiled function takes the same four symbolic inputs.
        fn_inputs = [self.input_var, self.q_var, self.answer_var, self.input_mask_var]

        if self.mode == 'train':
            # The optimizer updates are only consumed by train_fn, so they
            # are built in train mode only.
            updates = lasagne.updates.adadelta(self.loss, self.params)

            print("==> compiling train_fn")
            self.train_fn = theano.function(inputs=fn_inputs,
                                            outputs=[self.prediction, self.loss],
                                            updates=updates)

        print("==> compiling test_fn")
        self.test_fn = theano.function(inputs=fn_inputs,
                                       outputs=[self.prediction, self.loss, self.inp_c, self.q_q, last_mem])

        if self.mode == 'train':
            print("==> computing gradients (for debugging)")
            gradient = T.grad(self.loss, self.params)
            self.get_gradient_fn = theano.function(inputs=fn_inputs, outputs=gradient)