Example #1
    def new_episode(self, mem):
        tmp = theano.shared(np.zeros([self.batch_size, 1], dtype='float32'))
        z, z_updates = theano.scan(fn=self.episode_compute_z,
                                   sequences=self.inp_c,
                                   non_sequences=[mem, self.q_q],
                                   outputs_info=T.zeros_like(tmp))

        g, g_updates = theano.scan(
            fn=self.episode_compute_g,
            sequences=z,
            non_sequences=z,
        )

        if (self.normalize_attention):
            g = nn_utils.softmax(g)

        self.attentions.append(g)

        e, e_updates = theano.scan(fn=self.episode_attend,
                                   sequences=[self.inp_c, g],
                                   outputs_info=T.zeros_like(self.inp_c[0]))

        return e[-1]
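The helpers `episode_compute_z`, `episode_compute_g` and `episode_attend` are defined elsewhere in the class. For orientation, the gated episode update in the original DMN formulation is h_t = g_t · GRU(c_t, h_{t-1}) + (1 - g_t) · h_{t-1}, with the episode taken as the last hidden state. A minimal NumPy sketch under that assumption (the helper names and the toy GRU stand-in are illustrative, not the project's actual code):

import numpy as np

def gated_episode_step(c_t, g_t, h_prev, gru_update):
    # DMN-style episode update: the gate g_t interpolates between keeping the
    # previous hidden state and applying the GRU to the current fact c_t.
    return g_t * gru_update(c_t, h_prev) + (1.0 - g_t) * h_prev

# Toy usage with a stand-in "GRU" just to show the data flow.
dim = 4
facts = np.random.randn(3, dim).astype('float32')
gates = np.array([0.1, 0.8, 0.3], dtype='float32')
h = np.zeros(dim, dtype='float32')
for c_t, g_t in zip(facts, gates):
    h = gated_episode_step(c_t, g_t, h, lambda c, hp: np.tanh(c + hp))
episode = h  # analogous to e[-1] returned above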
    def new_episode(self, mem):
        #epi_dummy = theano.shared(np.zeros((self.dim,), dtype = floatX))
        g, g_updates = theano.scan(
            fn=self.new_attention_step,
            sequences=self.inp_batch_dimshuffled,  #980 x 512 x 50
            non_sequences=[mem, self.q_q],
            #outputs_info=T.zeros_like(epi_dummy))
            outputs_info=T.zeros_like(self.inp_batch_dimshuffled[0][0]),
            truncate_gradient=self.truncate_gradient)

        if (self.normalize_attention):
            g = nn_utils.softmax(g)

        #epi_dummy2 = theano.shared(np.zeros((self.dim,self.dim), dtype = floatX))
        e, e_updates = theano.scan(
            fn=self.new_episode_step,
            sequences=[self.inp_batch_dimshuffled, g],
            #outputs_info=T.zeros_like(epi_dummy2))
            outputs_info=T.zeros_like(self.inp_batch_dimshuffled[0]),
            truncate_gradient=self.truncate_gradient)

        e_list = []
        for index in range(self.batch_size * self.story_len):
            e_list.append(e[-1, :, index])
        return T.stack(e_list).dimshuffle((1, 0))
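All of these `new_episode` variants rely on the same `theano.scan` calling convention: `sequences` feeds one slice per step, `outputs_info` seeds the recurrent state that the step function receives after the sequence slices, and `truncate_gradient` bounds how far backpropagation through time unrolls (-1 means no truncation). A self-contained toy scan, unrelated to the DMN itself, showing that convention:

import numpy as np
import theano
import theano.tensor as T

X = T.matrix('X')                                      # (seq_len, dim)
acc, _ = theano.scan(fn=lambda x_t, prev: prev + x_t,  # (sequence slice, previous output)
                     sequences=X,
                     outputs_info=T.zeros_like(X[0]),
                     truncate_gradient=-1)             # -1: full backprop through time

running_sum = theano.function([X], acc[-1])
print(running_sum(np.ones((5, 3), dtype=theano.config.floatX)))  # -> [5. 5. 5.]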
 def answer_step(prev_a, prev_y):
     a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                       self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res, 
                       self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                       self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid)
     
     y = nn_utils.softmax(T.dot(self.W_a, a))
     return [a, y]
            def answer_step(prev_a, prev_y):
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                    self.W_ans_res_in, self.W_ans_res_hid,
                                    self.b_ans_res, self.W_ans_upd_in,
                                    self.W_ans_upd_hid, self.b_ans_upd,
                                    self.W_ans_hid_in, self.W_ans_hid_hid,
                                    self.b_ans_hid)

                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]  #, theano.scan_module.until(n>=max_n)) # or argmax==self.end_tag)
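The trailing comment refers to Theano's early-stopping mechanism for scan: the step function may additionally return a `theano.scan_module.until(...)` condition, and the loop stops once it evaluates to true. A small sketch adapted from the Theano documentation (independent of the answer module above):

import theano
import theano.tensor as T

def power_of_2(previous_power, max_value):
    new_power = previous_power * 2
    # Returning an `until` object alongside the output stops the scan early.
    return new_power, theano.scan_module.until(new_power > max_value)

max_value = T.scalar('max_value')
values, _ = theano.scan(power_of_2,
                        outputs_info=T.constant(1.),
                        non_sequences=max_value,
                        n_steps=1024)
powers = theano.function([max_value], values)
print(powers(40))  # -> [ 2.  4.  8. 16. 32. 64.]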
 def new_episode(self, mem):
     g, g_updates = theano.scan(fn=self.new_attention_step,
         sequences=self.inp_c,
         non_sequences=[mem, self.q_q],
         outputs_info=T.zeros_like(self.inp_c[0][0])) 
     
     if (self.normalize_attention):
         g = nn_utils.softmax(g)
     
     e, e_updates = theano.scan(fn=self.new_episode_step,
         sequences=[self.inp_c, g],
         outputs_info=T.zeros_like(self.inp_c[0]))
     
     return e[-1]
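`GRU_update`, used throughout these snippets, is not shown either. Its parameter list (reset, update and candidate weights for the input and the hidden state, plus biases) matches a standard GRU cell whose weights are applied as matrix-vector products. A NumPy sketch under that assumption, keeping in mind that which state the update gate keeps versus replaces differs between implementations:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_update(h, x,
               W_res_in, W_res_hid, b_res,
               W_upd_in, W_upd_hid, b_upd,
               W_hid_in, W_hid_hid, b_hid):
    # Standard GRU cell with the same parameter layout as the GRU_update calls above.
    r = sigmoid(W_res_in.dot(x) + W_res_hid.dot(h) + b_res)   # reset gate
    z = sigmoid(W_upd_in.dot(x) + W_upd_hid.dot(h) + b_upd)   # update gate
    h_cand = np.tanh(W_hid_in.dot(x) + r * W_hid_hid.dot(h) + b_hid)
    return (1.0 - z) * h + z * h_cand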
Example #7
    def new_episode(self, mem):
        g, g_updates = theano.scan(fn=self.new_attention_step,
                                   sequences=self.inp_c,
                                   non_sequences=[mem, self.q_q],
                                   outputs_info=T.zeros_like(self.inp_c[0][0]))

        if (self.normalize_attention):
            g = nn_utils.softmax(g)

        e, e_updates = theano.scan(fn=self.new_episode_step,
                                   sequences=[self.inp_c, g],
                                   outputs_info=T.zeros_like(self.inp_c[0]))

        return e[-1]
    def new_img_episode(self, mem):
        g, g_updates = theano.scan(fn=self.new_img_attention_step,
            sequences=self.img_inp_c,
            non_sequences=[mem, self.q_q],
            outputs_info=T.zeros_like(self.img_inp_c[0][0]))

        if (self.normalize_attention):
            g = nn_utils.softmax(g)

        e, e_updates = theano.scan(fn=self.new_img_episode_step,
            sequences=[self.img_inp_c, g],
            outputs_info=T.zeros_like(self.img_inp_c[0]))

        e_list = []
        for index in range(self.batch_size):
            e_list.append(e[self.img_seq_len - 1, :, index])
        return T.stack(e_list).dimshuffle((1, 0))
Example #9
    def new_episode(self, mem):
        g, g_updates = theano.scan(fn=self.new_attention_step,
                                   sequences=self.inp_c,
                                   non_sequences=[mem, self.q_q],
                                   outputs_info=T.zeros_like(self.inp_c[0][0]))

        if (self.normalize_attention):
            g = nn_utils.softmax(g)

        e, e_updates = theano.scan(fn=self.new_episode_step,
                                   sequences=[self.inp_c, g],
                                   outputs_info=T.zeros_like(self.inp_c[0]))

        e_list = []
        for index in range(self.batch_size * self.story_len):
            e_list.append(e[-1, :, index])
        return T.stack(e_list).dimshuffle((1, 0))
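`new_attention_step` typically builds an interaction feature vector between a fact c, the current memory m and the question q, and scores it with a small two-layer network. A hedged NumPy sketch of that gate; the 4·dim feature set matches the `W_1` shape (dim, 4·dim) used in some of the constructors below, while other snippets use a larger 7·dim variant that also includes c, m and q themselves:

import numpy as np

def attention_gate(c, m, q, W_1, b_1, W_2, b_2):
    # Interaction features between fact, memory and question, followed by a
    # two-layer scoring network; the result is one unnormalized gate score.
    z = np.concatenate([c * q, c * m, np.abs(c - q), np.abs(c - m)])
    return W_2.dot(np.tanh(W_1.dot(z) + b_1)) + b_2

dim = 4
c, m, q = (np.random.randn(dim) for _ in range(3))
W_1, b_1 = np.random.randn(dim, 4 * dim), np.zeros(dim)
W_2, b_2 = np.random.randn(1, dim), np.zeros(1)
print(attention_gate(c, m, q, W_1, b_1, W_2, b_2))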
    def new_episode(self, mem):
        z, z_updates = theano.scan(fn=self.episode_compute_z,
            sequences=self.inp_c,
            non_sequences=[mem, self.q_q],
            outputs_info=T.zeros_like(self.b_2))

        g, g_updates = theano.scan(fn=self.episode_compute_g,
            sequences=z,
            non_sequences=z,)
            
        if (self.normalize_attention):
            g = nn_utils.softmax(g) 

        self.attentions.append(g)

        e, e_updates = theano.scan(fn=self.episode_attend,
            sequences=[self.inp_c, g],
            outputs_info=T.zeros_like(self.inp_c[0]))
        
        return e[-1] 
Example #11
    def new_episode(self, mem):
        z, z_updates = theano.scan(fn=self.episode_compute_z,
            sequences=self.inp_c,
            non_sequences=[mem, self.q_q],
            outputs_info=T.zeros_like(self.b_2))

        g, g_updates = theano.scan(fn=self.episode_compute_g,
            sequences=z,
            non_sequences=z,)
            
        if (self.normalize_attention):
            g = nn_utils.softmax(g) 

        self.attentions.append(g)

        e, e_updates = theano.scan(fn=self.episode_attend,
            sequences=[self.inp_c, g],
            outputs_info=T.zeros_like(self.inp_c[0]))
        
        return e[-1] 
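The constructors below call `new_episode` once per memory hop: the memory starts at the question encoding, each hop folds the new episode into it with a GRU step, and the final memory feeds the answer module. A compact sketch of that control flow (hypothetical helpers; the real code passes the explicit weight matrices to `GRU_update`):

def episodic_memory(q_encoding, memory_hops, new_episode, gru_update):
    # memory starts at the question encoding; each hop attends over the
    # facts and updates the memory with one GRU step.
    memory = q_encoding
    for _ in range(memory_hops):
        episode = new_episode(memory)
        memory = gru_update(memory, episode)
    return memory  # last_mem, consumed by the answer module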
    def __init__(self, babi_train_raw, babi_test_raw, word2vec,
                 word_vector_size, dim, mode, answer_module, input_mask_mode,
                 memory_hops, batch_size, l2, normalize_attention, **kwargs):

        print "==> not used params in DMN class:", kwargs.keys()

        self.vocab = {}
        self.ivocab = {}

        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.batch_size = batch_size
        self.l2 = l2
        self.normalize_attention = normalize_attention

        self.max_fact_count = 0

        self.train_input, self.train_q, self.train_answer, self.train_fact_count, self.train_input_mask = self._process_input(
            babi_train_raw)
        self.test_input, self.test_q, self.test_answer, self.test_fact_count, self.test_input_mask = self._process_input(
            babi_test_raw)
        self.vocab_size = len(self.vocab)

        self.input_var = T.tensor3(
            'input_var')  # (batch_size, seq_len, glove_dim)
        self.q_var = T.tensor3('question_var')  # as self.input_var
        self.answer_var = T.ivector(
            'answer_var')  # answer of example in minibatch
        self.fact_count_var = T.ivector(
            'fact_count_var')  # number of facts in the example of minibatch
        self.input_mask_var = T.imatrix(
            'input_mask_var')  # (batch_size, indices)

        print "==> building input module"
        self.W_inp_res_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_upd_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_hid_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        input_var_shuffled = self.input_var.dimshuffle(1, 2, 0)
        inp_dummy = theano.shared(
            np.zeros((self.dim, self.batch_size), dtype=floatX))
        inp_c_history, _ = theano.scan(fn=self.input_gru_step,
                                       sequences=input_var_shuffled,
                                       outputs_info=T.zeros_like(inp_dummy))

        inp_c_history_shuffled = inp_c_history.dimshuffle(2, 0, 1)

        inp_c_list = []
        inp_c_mask_list = []
        for batch_index in range(self.batch_size):
            taken = inp_c_history_shuffled[batch_index].take(
                self.input_mask_var[
                    batch_index, :self.fact_count_var[batch_index]],
                axis=0)
            inp_c_list.append(
                T.concatenate([
                    taken,
                    T.zeros((self.max_fact_count - taken.shape[0], self.dim),
                            floatX)
                ]))
            inp_c_mask_list.append(
                T.concatenate([
                    T.ones((taken.shape[0], ), np.int32),
                    T.zeros((self.max_fact_count - taken.shape[0], ), np.int32)
                ]))

        self.inp_c = T.stack(inp_c_list).dimshuffle(1, 2, 0)
        inp_c_mask = T.stack(inp_c_mask_list).dimshuffle(1, 0)

        q_var_shuffled = self.q_var.dimshuffle(1, 2, 0)
        q_dummy = theano.shared(
            np.zeros((self.dim, self.batch_size), dtype=floatX))
        q_q_history, _ = theano.scan(fn=self.input_gru_step,
                                     sequences=q_var_shuffled,
                                     outputs_info=T.zeros_like(q_dummy))
        self.q_q = q_q_history[-1]

        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1,
                                         shape=(self.dim, 7 * self.dim + 0))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim, ))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1, ))

        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(
                self.GRU_update(memory[iter - 1], current_episode,
                                self.W_mem_res_in, self.W_mem_res_hid,
                                self.b_mem_res, self.W_mem_upd_in,
                                self.W_mem_upd_hid, self.b_mem_upd,
                                self.W_mem_hid_in, self.W_mem_hid_hid,
                                self.b_mem_hid))

        last_mem = memory[-1]

        print "==> building answer module"
        self.W_a = nn_utils.normal_param(std=0.1,
                                         shape=(self.vocab_size, self.dim))

        if self.answer_module == 'feedforward':
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))

        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_res_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_res = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            self.W_ans_upd_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_upd_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_upd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            self.W_ans_hid_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_hid_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_hid = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            def answer_step(prev_a, prev_y):
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                    self.W_ans_res_in, self.W_ans_res_hid,
                                    self.b_ans_res, self.W_ans_upd_in,
                                    self.W_ans_upd_hid, self.b_ans_upd,
                                    self.W_ans_hid_in, self.W_ans_hid_hid,
                                    self.b_ans_hid)

                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]

            # TODO: add conditional ending
            dummy = theano.shared(
                np.zeros((self.vocab_size, self.batch_size), dtype=floatX))
            results, updates = theano.scan(
                fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(dummy)],  #(last_mem, y)
                n_steps=1)
            self.prediction = results[1][-1]

        else:
            raise Exception("invalid answer_module")

        self.prediction = self.prediction.dimshuffle(1, 0)

        self.params = [
            self.W_inp_res_in,
            self.W_inp_res_hid,
            self.b_inp_res,
            self.W_inp_upd_in,
            self.W_inp_upd_hid,
            self.b_inp_upd,
            self.W_inp_hid_in,
            self.W_inp_hid_hid,
            self.b_inp_hid,
            self.W_mem_res_in,
            self.W_mem_res_hid,
            self.b_mem_res,
            self.W_mem_upd_in,
            self.W_mem_upd_hid,
            self.b_mem_upd,
            self.W_mem_hid_in,
            self.W_mem_hid_hid,
            self.b_mem_hid,  #self.W_b
            self.W_1,
            self.W_2,
            self.b_1,
            self.b_2,
            self.W_a
        ]

        if self.answer_module == 'recurrent':
            self.params = self.params + [
                self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid
            ]

        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction,
                                                       self.answer_var).mean()

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        updates = lasagne.updates.adadelta(self.loss, self.params)

        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.answer_var,
                    self.fact_count_var, self.input_mask_var
                ],
                outputs=[self.prediction, self.loss],
                updates=updates)

        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[
            self.input_var, self.q_var, self.answer_var, self.fact_count_var,
            self.input_mask_var
        ],
                                       outputs=[self.prediction, self.loss])
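A hypothetical call into the compiled functions above, with shapes taken from the comments on the symbolic variables (inputs of batch_size x seq_len x glove_dim, integer answer indices, per-example fact counts and fact-index masks). Real values come from the bAbI preprocessing; the random arrays below only illustrate the interface:

import numpy as np

batch_size, seq_len, glove_dim, max_facts = 10, 20, 50, 6
inp = np.random.randn(batch_size, seq_len, glove_dim).astype('float32')
q = np.random.randn(batch_size, 5, glove_dim).astype('float32')
ans = np.zeros(batch_size, dtype='int32')
fact_count = np.full(batch_size, max_facts, dtype='int32')
input_mask = np.zeros((batch_size, max_facts), dtype='int32')

# dmn stands for the constructed DMN object.
prediction, loss = dmn.train_fn(inp, q, ans, fact_count, input_mask)
prediction, loss = dmn.test_fn(inp, q, ans, fact_count, input_mask)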
    def __init__(self, stories, QAs, batch_size, story_v, learning_rate,
                 word_vector_size, sent_vector_size, dim, mode, answer_module,
                 input_mask_mode, memory_hops, l2, story_source,
                 normalize_attention, batch_norm, dropout, dropout_in,
                 **kwargs):

        #print "==> not used params in DMN class:", kwargs.keys()
        self.learning_rate = learning_rate
        self.rng = np.random
        self.rng.seed(1234)
        mqa = MovieQA.DataLoader()
        ### Load Word2Vec model
        w2v_model = w2v.load(w2v_mqa_model_filename, kind='bin')
        self.w2v = w2v_model
        self.d_w2v = len(w2v_model.get_vector(w2v_model.vocab[1]))
        self.word_thresh = 1
        print "Loaded word2vec model: dim = %d | vocab-size = %d" % (
            self.d_w2v, len(w2v_model.vocab))
        ### Create vocabulary-to-index and index-to-vocabulary
        v2i = {'': 0, 'UNK': 1}  # vocabulary to index
        QA_words, v2i = self.create_vocabulary(
            QAs,
            stories,
            v2i,
            w2v_vocab=w2v_model.vocab.tolist(),
            word_thresh=self.word_thresh)
        i2v = {v: k for k, v in v2i.iteritems()}
        self.vocab = v2i
        self.ivocab = i2v
        self.story_v = story_v
        self.word2vec = w2v_model
        self.word_vector_size = word_vector_size
        self.sent_vector_size = sent_vector_size
        self.dim = dim
        self.batch_size = batch_size
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.batch_norm = batch_norm
        self.dropout = dropout
        self.dropout_in = dropout_in

        #self.max_inp_sent_len = 0
        #self.max_q_len = 0

        ### Convert QAs and stories into numpy matrices (like in the bAbI data set)
        # storyM - Dictionary - indexed by imdb_key. Values are [num-sentence X max-num-words]
        # questionM - NP array - [num-question X max-num-words]
        # answerM - NP array - [num-question X num-answer-options X max-num-words]
        storyM, questionM, answerM = self.data_in_matrix_form(
            stories, QA_words, v2i)
        qinfo = self.associate_additional_QA_info(QAs)

        ### Split everything into train, val, and test data
        #train_storyM = {k:v for k, v in storyM.iteritems() if k in mqa.data_split['train']}
        #val_storyM   = {k:v for k, v in storyM.iteritems() if k in mqa.data_split['val']}
        #test_storyM  = {k:v for k, v in storyM.iteritems() if k in mqa.data_split['test']}

        def split_train_test(long_list, QAs, trnkey='train', tstkey='val'):
            # Create train/val/test splits based on key
            train_split = [
                item for k, item in enumerate(long_list)
                if QAs[k].qid.startswith('train')
            ]
            val_split = [
                item for k, item in enumerate(long_list)
                if QAs[k].qid.startswith('val')
            ]
            test_split = [
                item for k, item in enumerate(long_list)
                if QAs[k].qid.startswith('test')
            ]
            if type(long_list) == np.ndarray:
                return np.array(train_split), np.array(val_split), np.array(
                    test_split)
            else:
                return train_split, val_split, test_split

        train_questionM, val_questionM, test_questionM = split_train_test(
            questionM, QAs)
        train_answerM, val_answerM, test_answerM = split_train_test(
            answerM, QAs)
        train_qinfo, val_qinfo, test_qinfo = split_train_test(qinfo, QAs)

        QA_train = [qa for qa in QAs if qa.qid.startswith('train:')]
        QA_val = [qa for qa in QAs if qa.qid.startswith('val:')]
        QA_test = [qa for qa in QAs if qa.qid.startswith('test:')]

        #train_data = {'s':train_storyM, 'q':train_questionM, 'a':train_answerM, 'qinfo':train_qinfo}
        #val_data =   {'s':val_storyM,   'q':val_questionM,   'a':val_answerM,   'qinfo':val_qinfo}
        #test_data  = {'s':test_storyM,  'q':test_questionM,  'a':test_answerM,  'qinfo':test_qinfo}

        with open('train_split.json') as fid:
            trdev = json.load(fid)

        s_key = self.story_v.keys()
        self.train_range = [
            k for k, qi in enumerate(qinfo)
            if (qi['movie'] in trdev['train'] and qi['qid'] in s_key)
        ]
        self.train_val_range = [
            k for k, qi in enumerate(qinfo)
            if (qi['movie'] in trdev['dev'] and qi['qid'] in s_key)
        ]
        self.val_range = [
            k for k, qi in enumerate(val_qinfo) if qi['qid'] in s_key
        ]

        self.max_sent_len = max(
            [sty.shape[0] for sty in self.story_v.values()])
        self.train_input = self.story_v
        self.train_val_input = self.story_v
        self.test_input = self.story_v
        self.train_q = train_questionM
        self.train_answer = train_answerM
        self.train_qinfo = train_qinfo
        self.train_val_q = train_questionM
        self.train_val_answer = train_answerM
        self.train_val_qinfo = train_qinfo
        self.test_q = val_questionM
        self.test_answer = val_answerM
        self.test_qinfo = val_qinfo
        """Setup some configuration parts of the model.
        """
        self.v2i = v2i
        self.vs = len(v2i)
        self.d_lproj = 300

        # define Look-Up-Table mask
        np_mask = np.vstack(
            (np.zeros(self.d_w2v), np.ones((self.vs - 1, self.d_w2v))))
        T_mask = theano.shared(np_mask.astype(theano.config.floatX),
                               name='LUT_mask')

        # setup Look-Up-Table to be Word2Vec
        self.pca_mat = None
        print "Initialize LUTs as word2vec and use linear projection layer"

        self.LUT = np.zeros((self.vs, self.d_w2v), dtype='float32')
        found_words = 0
        for w, v in self.v2i.iteritems():
            if w in self.w2v.vocab:  # all valid words are already in vocab or 'UNK'
                self.LUT[v] = self.w2v.get_vector(w)
                found_words += 1
            else:
                # LUT[v] = np.zeros((self.d_w2v))
                self.LUT[v] = self.rng.randn(self.d_w2v)
                self.LUT[v] = self.LUT[v] / (np.linalg.norm(self.LUT[v]) +
                                             1e-6)

        print "Found %d / %d words" % (found_words, len(self.v2i))

        # word 0 is blanked out, word 1 is 'UNK'
        self.LUT[0] = np.zeros((self.d_w2v))

        # if linear projection layer is not the same shape as LUT, then initialize with PCA
        if self.d_lproj != self.LUT.shape[1]:
            pca = PCA(n_components=self.d_lproj, whiten=True)
            self.pca_mat = pca.fit_transform(self.LUT.T)  # 300 x 100?

        # setup LUT!
        self.T_w2v = theano.shared(self.LUT.astype(theano.config.floatX))

        self.train_input_mask = np_mask
        self.test_input_mask = np_mask
        #self.train_input, self.train_q, self.train_answer, self.train_input_mask = self._process_input(babi_train_raw)
        #self.test_input, self.test_q, self.test_answer, self.test_input_mask = self._process_input(babi_test_raw)
        self.vocab_size = len(self.vocab)

        self.input_var = T.tensor3(
            'input_var')  # batch-size X sentences X 4096
        self.q_var = T.matrix('question_var')  # batch-size X 300
        self.answer_var = T.tensor3(
            'answer_var')  # batch-size X multiple options X 300
        self.input_mask_var = T.imatrix('input_mask_var')
        self.target = T.ivector(
            'target'
        )  # batch-size ('single': word index,  'multi_choice': correct option)
        self.attentions = []

        #self.pe_matrix_in = self.pe_matrix(self.max_inp_sent_len)
        #self.pe_matrix_q = self.pe_matrix(self.max_q_len)

        print "==> building input module"

        #positional encoder weights
        self.W_pe = nn_utils.normal_param(std=0.1,
                                          shape=(self.vocab_size, self.dim))

        #biGRU input fusion weights
        self.W_inp_res_in_fwd = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_res_hid_fwd = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
        self.b_inp_res_fwd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

        self.W_inp_upd_in_fwd = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_upd_hid_fwd = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
        self.b_inp_upd_fwd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

        self.W_inp_hid_in_fwd = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_hid_hid_fwd = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
        self.b_inp_hid_fwd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

        self.W_inp_res_in_bwd = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_res_hid_bwd = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
        self.b_inp_res_bwd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

        self.W_inp_upd_in_bwd = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_upd_hid_bwd = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
        self.b_inp_upd_bwd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

        self.W_inp_hid_in_bwd = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_hid_hid_bwd = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
        self.b_inp_hid_bwd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

        #self.V_f = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        #self.V_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))

        self.inp_sent_reps = self.input_var
        self.ans_reps = self.answer_var
        self.inp_c = self.input_module_full(self.inp_sent_reps)
        self.q_q = self.q_var

        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        #self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        #self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 0))
        self.W_1 = nn_utils.normal_param(std=0.1,
                                         shape=(self.dim, 4 * self.dim + 0))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim, ))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1, ))

        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(
                self.GRU_update(memory[iter - 1], current_episode,
                                self.W_mem_res_in, self.W_mem_res_hid,
                                self.b_mem_res, self.W_mem_upd_in,
                                self.W_mem_upd_hid, self.b_mem_upd,
                                self.W_mem_hid_in, self.W_mem_hid_hid,
                                self.b_mem_hid))

        #last_mem_raw = memory[-1].dimshuffle(('x', 0))
        last_mem_raw = memory[-1]

        net = layers.InputLayer(shape=(self.batch_size, self.dim),
                                input_var=last_mem_raw)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net)[0]

        print "==> building answer module"
        self.W_a = nn_utils.normal_param(std=0.1, shape=(300, self.dim))

        if self.answer_module == 'feedforward':
            self.temp = T.dot(self.ans_reps, self.W_a)
            self.prediction = nn_utils.softmax(T.dot(self.temp, last_mem))

        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_res_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_res = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            self.W_ans_upd_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_upd_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_upd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            self.W_ans_hid_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_hid_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_hid = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            def answer_step(prev_a, prev_y):
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                    self.W_ans_res_in, self.W_ans_res_hid,
                                    self.b_ans_res, self.W_ans_upd_in,
                                    self.W_ans_upd_hid, self.b_ans_upd,
                                    self.W_ans_hid_in, self.W_ans_hid_hid,
                                    self.b_ans_hid)

                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]

            # add conditional ending?
            dummy = theano.shared(np.zeros((self.vocab_size, ), dtype=floatX))

            results, updates = theano.scan(
                fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(dummy)],
                n_steps=1)
            self.prediction = results[1][-1]

        else:
            raise Exception("invalid answer_module")

        print "==> collecting all parameters"
        self.params = [
            self.W_pe,
            self.W_inp_res_in_fwd,
            self.W_inp_res_hid_fwd,
            self.b_inp_res_fwd,
            self.W_inp_upd_in_fwd,
            self.W_inp_upd_hid_fwd,
            self.b_inp_upd_fwd,
            self.W_inp_hid_in_fwd,
            self.W_inp_hid_hid_fwd,
            self.b_inp_hid_fwd,
            self.W_inp_res_in_bwd,
            self.W_inp_res_hid_bwd,
            self.b_inp_res_bwd,
            self.W_inp_upd_in_bwd,
            self.W_inp_upd_hid_bwd,
            self.b_inp_upd_bwd,
            self.W_inp_hid_in_bwd,
            self.W_inp_hid_hid_bwd,
            self.b_inp_hid_bwd,
            self.W_mem_res_in,
            self.W_mem_res_hid,
            self.b_mem_res,
            self.W_mem_upd_in,
            self.W_mem_upd_hid,
            self.b_mem_upd,
            self.W_mem_hid_in,
            self.W_mem_hid_hid,
            self.b_mem_hid,  #self.W_b
            self.W_1,
            self.W_2,
            self.b_1,
            self.b_2,
            self.W_a
        ]

        if self.answer_module == 'recurrent':
            self.params = self.params + [
                self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid
            ]

        print "==> building loss layer and computing updates"
        #tmp= self.prediction.dimshuffle(2,0,1)
        #res, _ =theano.scan(fn = lambda inp: inp, sequences=tmp)
        #self.prediction = res[-1]
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction,
                                                       self.target)

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = T.mean(self.loss_ce) + self.loss_l2

        #updates = lasagne.updates.adadelta(self.loss, self.params)
        #updates = lasagne.updates.adam(self.loss, self.params)
        updates = lasagne.updates.adam(self.loss,
                                       self.params,
                                       learning_rate=self.learning_rate,
                                       beta1=0.5)  #from DCGAN paper
        #updates = lasagne.updates.adadelta(self.loss, self.params, learning_rate=0.0005)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.0003)

        self.attentions = T.stack(self.attentions)
        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.answer_var, self.target
                ],
                outputs=[self.prediction, self.loss, self.attentions],
                updates=updates,
                on_unused_input='warn',
                allow_input_downcast=True)

        print "==> compiling test_fn"
        self.test_fn = theano.function(
            inputs=[self.input_var, self.q_var, self.answer_var, self.target],
            outputs=[self.prediction, self.loss, self.attentions],
            on_unused_input='warn',
            allow_input_downcast=True)
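The MovieQA variant is called the same way, but with one question vector per example, a 300-d vector per answer option, and the index of the correct option as the target; it additionally returns the stacked attention weights. A hypothetical call pattern (shapes follow the variable comments; real features would come from the MovieQA loader):

import numpy as np

batch_size, n_sent, feat_dim, n_options = 8, 40, 4096, 5
story = np.random.randn(batch_size, n_sent, feat_dim).astype('float32')
question = np.random.randn(batch_size, 300).astype('float32')
options = np.random.randn(batch_size, n_options, 300).astype('float32')
target = np.zeros(batch_size, dtype='int32')

# model stands for the constructed object above.
prediction, loss, attentions = model.train_fn(story, question, options, target)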
    def __init__(self, babi_train_raw, babi_test_raw, word2vec, word_vector_size, sent_vector_size, 
                dim, mode, answer_module, input_mask_mode, memory_hops, l2, 
                normalize_attention, batch_norm, dropout, dropout_in, **kwargs):

        print "==> not used params in DMN class:", kwargs.keys()
        self.vocab = {None: 0}
        self.ivocab = {0: None}
        
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.sent_vector_size = sent_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.batch_norm = batch_norm
        self.dropout = dropout
        self.dropout_in = dropout_in

        self.max_inp_sent_len = 0
        self.max_q_len = 0

        """
        #To Use All Vocab
        self.vocab = {None: 0, 'jason': 134.0, 'office': 14.0, 'yellow': 78.0, 'bedroom': 24.0, 'go': 108.0, 'yes': 15.0, 'antoine': 138.0, 'milk': 139.0, 'before': 46.0, 'grabbed': 128.0, 'fit': 100.0, 'how': 105.0, 'swan': 73.0, 'than': 96.0, 'to': 13.0, 'does': 99.0, 's,e': 110.0, 'east': 102.0, 'rectangle': 82.0, 'gave': 149.0, 'then': 39.0, 'evening': 48.0, 'triangle': 79.0, 'garden': 37.0, 'get': 131.0, 'football,apple,milk': 179.0, 'they': 41.0, 'not': 178.0, 'bigger': 95.0, 'gray': 77.0, 'school': 6.0, 'apple': 142.0, 'did': 127.0, 'morning': 44.0, 'discarded': 146.0, 'julius': 72.0, 'she': 29.0, 'went': 11.0, 'where': 30.0, 'jeff': 152.0, 'square': 84.0, 'who': 153.0, 'tired': 124.0, 'there': 130.0, 'back': 12.0, 'lion': 70.0, 'are': 50.0, 'picked': 143.0, 'e,e': 119.0, 'pajamas': 129.0, 'Mary': 157.0, 'blue': 83.0, 'what': 63.0, 'container': 98.0, 'rhino': 76.0, 'daniel': 31.0, 'bernhard': 67.0, 'milk,football': 172.0, 'above': 80.0, 'got': 136.0, 'emily': 60.0, 'red': 88.0, 'either': 3.0, 'sheep': 58.0, 'football': 137.0, 'jessica': 61.0, 'do': 106.0, 'Bill': 155.0, 'football,apple': 168.0, 'fred': 1.0, 'winona': 59.0, 'objects': 161.0, 'put': 147.0, 'kitchen': 17.0, 'box': 90.0, 'received': 154.0, 'journeyed': 25.0, 'of': 52.0, 'wolf': 62.0, 'afternoon': 47.0, 'or': 7.0, 'south': 112.0, 's,w': 114.0, 'afterwards': 32.0, 'sumit': 123.0, 'color': 75.0, 'julie': 23.0, 'one': 163.0, 'down': 148.0, 'nothing': 167.0, 'n,n': 113.0, 'right': 86.0, 's,s': 116.0, 'gertrude': 54.0, 'bathroom': 26.0, 'from': 109.0, 'west': 104.0, 'chocolates': 91.0, 'two': 165.0, 'frog': 66.0, '.': 9.0, 'cats': 57.0, 'apple,milk,football': 175.0, 'passed': 158.0, 'apple,football,milk': 176.0, 'white': 71.0, 'john': 35.0, 'was': 45.0, 'mary': 10.0, 'apple,football': 170.0, 'north': 103.0, 'n,w': 111.0, 'that': 28.0, 'park': 8.0, 'took': 141.0, 'chocolate': 101.0, 'carrying': 162.0, 'n,e': 120.0, 'mice': 49.0, 'travelled': 22.0, 'he': 33.0, 'none': 164.0, 'bored': 133.0, 'e,n': 117.0, None: 0, 'Jeff': 159.0, 'this': 43.0, 'inside': 93.0, 'bill': 16.0, 'up': 144.0, 'cat': 64.0, 'will': 125.0, 'below': 87.0, 'greg': 74.0, 'three': 166.0, 'suitcase': 97.0, 'following': 36.0, 'e,s': 115.0, 'and': 40.0, 'thirsty': 135.0, 'cinema': 19.0, 'is': 2.0, 'moved': 18.0, 'yann': 132.0, 'sphere': 89.0, 'dropped': 145.0, 'in': 4.0, 'mouse': 56.0, 'football,milk': 171.0, 'pink': 81.0, 'afraid': 51.0, 'no': 20.0, 'Fred': 156.0, 'w,s': 121.0, 'handed': 151.0, 'w,w': 118.0, 'brian': 69.0, 'chest': 94.0, 'w,n': 122.0, 'you': 107.0, 'many': 160.0, 'lily': 65.0, 'hallway': 34.0, 'why': 126.0, 'after': 27.0, 'yesterday': 42.0, 'sandra': 38.0, 'fits': 92.0, 'milk,football,apple': 173.0, 'the': 5.0, 'milk,apple': 169.0, 'a': 55.0, 'give': 150.0, 'longer': 177.0, 'maybe': 21.0, 'hungry': 140.0, 'apple,milk': 174.0, 'green': 68.0, 'wolves': 53.0, 'left': 85.0}
        self.ivocab = {0: None, 1: 'fred', 2: 'is', 3: 'either', 4: 'in', 5: 'the', 6: 'school', 7: 'or', 8: 'park', 9: '.', 10: 'mary', 11: 'went', 12: 'back', 13: 'to', 14: 'office', 15: 'yes', 16: 'bill', 17: 'kitchen', 18: 'moved', 19: 'cinema', 20: 'no', 21: 'maybe', 22: 'travelled', 23: 'julie', 24: 'bedroom', 25: 'journeyed', 26: 'bathroom', 27: 'after', 28: 'that', 29: 'she', 30: 'where', 31: 'daniel', 32: 'afterwards', 33: 'he', 34: 'hallway', 35: 'john', 36: 'following', 37: 'garden', 38: 'sandra', 39: 'then', 40: 'and', 41: 'they', 42: 'yesterday', 43: 'this', 44: 'morning', 45: 'was', 46: 'before', 47: 'afternoon', 48: 'evening', 49: 'mice', 50: 'are', 51: 'afraid', 52: 'of', 53: 'wolves', 54: 'gertrude', 55: 'a', 56: 'mouse', 57: 'cats', 58: 'sheep', 59: 'winona', 60: 'emily', 61: 'jessica', 62: 'wolf', 63: 'what', 64: 'cat', 65: 'lily', 66: 'frog', 67: 'bernhard', 68: 'green', 69: 'brian', 70: 'lion', 71: 'white', 72: 'julius', 73: 'swan', 74: 'greg', 75: 'color', 76: 'rhino', 77: 'gray', 78: 'yellow', 79: 'triangle', 80: 'above', 81: 'pink', 82: 'rectangle', 83: 'blue', 84: 'square', 85: 'left', 86: 'right', 87: 'below', 88: 'red', 89: 'sphere', 90: 'box', 91: 'chocolates', 92: 'fits', 93: 'inside', 94: 'chest', 95: 'bigger', 96: 'than', 97: 'suitcase', 98: 'container', 99: 'does', 100: 'fit', 101: 'chocolate', 102: 'east', 103: 'north', 104: 'west', 105: 'how', 106: 'do', 107: 'you', 108: 'go', 109: 'from', 110: 's,e', 111: 'n,w', 112: 'south', 113: 'n,n', 114: 's,w', 115: 'e,s', 116: 's,s', 117: 'e,n', 118: 'w,w', 119: 'e,e', 120: 'n,e', 121: 'w,s', 122: 'w,n', 123: 'sumit', 124: 'tired', 125: 'will', 126: 'why', 127: 'did', 128: 'grabbed', 129: 'pajamas', 130: 'there', 131: 'get', 132: 'yann', 133: 'bored', 134: 'jason', 135: 'thirsty', 136: 'got', 137: 'football', 138: 'antoine', 139: 'milk', 140: 'hungry', 141: 'took', 142: 'apple', 143: 'picked', 144: 'up', 145: 'dropped', 146: 'discarded', 147: 'put', 148: 'down', 149: 'gave', 150: 'give', 151: 'handed', 152: 'jeff', 153: 'who', 154: 'received', 155: 'Bill', 156: 'Fred', 157: 'Mary', 158: 'passed', 159: 'Jeff', 160: 'many', 161: 'objects', 162: 'carrying', 163: 'one', 164: 'none', 165: 'two', 166: 'three', 167: 'nothing', 168: 'football,apple', 169: 'milk,apple', 170: 'apple,football', 171: 'football,milk', 172: 'milk,football', 173: 'milk,football,apple', 174: 'apple,milk', 175: 'apple,milk,football', 176: 'apple,football,milk', 177: 'longer', 178: 'not', 179: 'football,apple,milk'}
        #"""
        
        self.train_input, self.train_q, self.train_answer, self.train_input_mask = self._process_input(babi_train_raw)
        self.test_input, self.test_q, self.test_answer, self.test_input_mask = self._process_input(babi_test_raw)
        self.vocab_size = len(self.vocab)

        self.input_var = T.imatrix('input_var')
        self.q_var = T.ivector('question_var')
        self.answer_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')
        
        self.attentions = []

        self.pe_matrix_in = self.pe_matrix(self.max_inp_sent_len)
        self.pe_matrix_q = self.pe_matrix(self.max_q_len)

            
        print "==> building input module"

        #positional encoder weights
        self.W_pe = nn_utils.normal_param(std=0.1, shape=(self.vocab_size, self.dim))

        #biGRU input fusion weights
        self.W_inp_res_in_fwd = nn_utils.normal_param(std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_res_hid_fwd = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_res_fwd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_inp_upd_in_fwd = nn_utils.normal_param(std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_upd_hid_fwd = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_upd_fwd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_inp_hid_in_fwd = nn_utils.normal_param(std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_hid_hid_fwd = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_hid_fwd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_inp_res_in_bwd = nn_utils.normal_param(std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_res_hid_bwd = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_res_bwd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_inp_upd_in_bwd = nn_utils.normal_param(std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_upd_hid_bwd = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_upd_bwd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_inp_hid_in_bwd = nn_utils.normal_param(std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_hid_hid_bwd = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_hid_bwd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        #self.V_f = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        #self.V_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))

        self.inp_sent_reps, _ = theano.scan(
                                fn=self.sum_pos_encodings_in,
                                sequences=self.input_var)

        self.inp_sent_reps_stacked = T.stacklists(self.inp_sent_reps)
        #self.inp_c = self.input_module_full(self.inp_sent_reps_stacked)

        self.inp_c = self.input_module_full(self.inp_sent_reps)

        self.q_q = self.sum_pos_encodings_q(self.q_var)
                
        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1, shape=(self.memory_hops, self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1, shape=(self.memory_hops, self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.memory_hops, self.dim,))
        
        self.W_mem_upd_in = nn_utils.normal_param(std=0.1, shape=(self.memory_hops, self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.memory_hops, self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.memory_hops, self.dim,))
        
        self.W_mem_hid_in = nn_utils.normal_param(std=0.1, shape=(self.memory_hops, self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.memory_hops, self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.memory_hops, self.dim,))
        
        #self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        #self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 0))
        self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.memory_hops, self.dim, 4 * self.dim + 0))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(self.memory_hops, 1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.memory_hops, self.dim,))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(self.memory_hops, 1,))


        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            self.mem_weight_num = int(iter - 1)
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(self.GRU_update(memory[iter - 1], current_episode,
                                          self.W_mem_res_in[self.mem_weight_num], self.W_mem_res_hid[self.mem_weight_num], self.b_mem_res[self.mem_weight_num], 
                                          self.W_mem_upd_in[self.mem_weight_num], self.W_mem_upd_hid[self.mem_weight_num], self.b_mem_upd[self.mem_weight_num],
                                          self.W_mem_hid_in[self.mem_weight_num], self.W_mem_hid_hid[self.mem_weight_num], self.b_mem_hid[self.mem_weight_num]))
        
        last_mem_raw = memory[-1].dimshuffle(('x', 0))
        
        net = layers.InputLayer(shape=(1, self.dim), input_var=last_mem_raw)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net)[0]
        
        print "==> building answer module"
        self.W_a = nn_utils.normal_param(std=0.1, shape=(self.vocab_size, self.dim))
        
        if self.answer_module == 'feedforward':
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))
        
        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))
            
            self.W_ans_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
            
            self.W_ans_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
            def answer_step(prev_a, prev_y):
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                  self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res, 
                                  self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                                  self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid)
                
                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]
            
            # add conditional ending?
            dummy = theano.shared(np.zeros((self.vocab_size, ), dtype=floatX))
            
            results, updates = theano.scan(fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(dummy)],
                n_steps=1)
            self.prediction = results[1][-1]
        
        else:
            raise Exception("invalid answer_module")
        
        
        print "==> collecting all parameters"
        self.params = [self.W_pe,
                  self.W_inp_res_in_fwd, self.W_inp_res_hid_fwd, self.b_inp_res_fwd, 
                  self.W_inp_upd_in_fwd, self.W_inp_upd_hid_fwd, self.b_inp_upd_fwd,
                  self.W_inp_hid_in_fwd, self.W_inp_hid_hid_fwd, self.b_inp_hid_fwd,
                  self.W_inp_res_in_bwd, self.W_inp_res_hid_bwd, self.b_inp_res_bwd, 
                  self.W_inp_upd_in_bwd, self.W_inp_upd_hid_bwd, self.b_inp_upd_bwd,
                  self.W_inp_hid_in_bwd, self.W_inp_hid_hid_bwd, self.b_inp_hid_bwd, 
                  self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                  self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                  self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, #self.W_b
                  self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]

        if self.answer_module == 'recurrent':
            self.params = self.params + [self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res, 
                              self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                              self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid]
        
        
        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction.dimshuffle('x', 0), 
                                                       T.stack([self.answer_var]))[0]

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0
        
        self.loss = self.loss_ce + self.loss_l2
        
        #updates = lasagne.updates.adadelta(self.loss, self.params)
        #updates = lasagne.updates.adam(self.loss, self.params)
        updates = lasagne.updates.adam(self.loss, self.params, learning_rate=0.0001, beta1=0.5)  # from DCGAN paper
        #updates = lasagne.updates.adadelta(self.loss, self.params, learning_rate=0.0005)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.0003)
        
        self.attentions = T.stack(self.attentions)
        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var], 
                                            outputs=[self.prediction, self.loss, self.attentions],
                                            updates=updates,
                                            on_unused_input='warn',
                                            allow_input_downcast=True)
        
        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var],
                                       outputs=[self.prediction, self.loss, self.attentions],
                                       on_unused_input='warn',
                                       allow_input_downcast=True)
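
The GRU_update helper that these snippets call everywhere is defined elsewhere in the class and is not shown here. For reference, below is a minimal NumPy sketch of the standard GRU gating it presumably implements (reset gate, update gate, candidate state); the name gru_update, the interpolation convention, and the smoke test are illustrative assumptions, not the snippet's own code.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_update(h_prev, x,
               W_res_in, W_res_hid, b_res,
               W_upd_in, W_upd_hid, b_upd,
               W_hid_in, W_hid_hid, b_hid):
    # one GRU step: h_prev is the previous hidden state, x the current input;
    # weight shapes mirror the snippets: W_*_in is (dim, input_dim), W_*_hid is (dim, dim)
    r = sigmoid(W_res_in.dot(x) + W_res_hid.dot(h_prev) + b_res)            # reset gate
    z = sigmoid(W_upd_in.dot(x) + W_upd_hid.dot(h_prev) + b_upd)            # update gate
    h_tilde = np.tanh(W_hid_in.dot(x) + W_hid_hid.dot(r * h_prev) + b_hid)  # candidate state
    return z * h_prev + (1.0 - z) * h_tilde                                 # interpolate old/new

# tiny smoke test with toy dimensions
dim, inp = 4, 6
rng = np.random.RandomState(0)
h = gru_update(np.zeros(dim), rng.randn(inp),
               rng.randn(dim, inp), rng.randn(dim, dim), np.zeros(dim),
               rng.randn(dim, inp), rng.randn(dim, dim), np.zeros(dim),
               rng.randn(dim, inp), rng.randn(dim, dim), np.zeros(dim))
assert h.shape == (dim,)
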
    def __init__(self, babi_train_raw, babi_test_raw, word2vec, word_vector_size, 
                dim, mode, answer_module, input_mask_mode, memory_hops, l2, 
                normalize_attention, **kwargs):

        print "==> not used params in DMN class:", kwargs.keys()
        self.vocab = {}
        self.ivocab = {}
        
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention
        
        self.train_input, self.train_q, self.train_answer, self.train_input_mask = self._process_input(babi_train_raw)
        self.test_input, self.test_q, self.test_answer, self.test_input_mask = self._process_input(babi_test_raw)
        self.vocab_size = len(self.vocab)

        self.input_var = T.matrix('input_var')
        self.q_var = T.matrix('question_var')
        self.answer_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')
        
            
        print "==> building input module"
        self.W_inp_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_inp_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_inp_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        inp_c_history, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=self.input_var,
                    outputs_info=T.zeros_like(self.b_inp_hid))
        
        self.inp_c = inp_c_history.take(self.input_mask_var, axis=0)
        
        self.q_q, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=self.q_var,
                    outputs_info=T.zeros_like(self.b_inp_hid))

        self.q_q = self.q_q[-1]
        
        
        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_mem_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_mem_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 2))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1,))
        

        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(self.GRU_update(memory[iter - 1], current_episode,
                                          self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                                          self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                                          self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid))
        
        last_mem = memory[-1]
        
        print "==> building answer module"
        self.W_a = nn_utils.normal_param(std=0.1, shape=(self.vocab_size, self.dim))
        
        if self.answer_module == 'feedforward':
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))
        
        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))
            
            self.W_ans_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
            
            self.W_ans_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
            def answer_step(prev_a, prev_y):
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                  self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res, 
                                  self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                                  self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid)
                
                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]
            
            # TODO: add conditional ending
            dummy = theano.shared(np.zeros((self.vocab_size, ), dtype=floatX))
            results, updates = theano.scan(fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(dummy)],
                n_steps=1)
            self.prediction = results[1][-1]
        
        else:
            raise Exception("invalid answer_module")
        
        
        print "==> collecting all parameters"
        self.params = [self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res, 
                  self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
                  self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
                  self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                  self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                  self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid,
                  self.W_b, self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]
        
        if self.answer_module == 'recurrent':
            self.params = self.params + [self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res, 
                              self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                              self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid]
        
        
        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction.dimshuffle('x', 0), T.stack([self.answer_var]))[0]
        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0
        
        self.loss = self.loss_ce + self.loss_l2
        
        updates = lasagne.updates.adadelta(self.loss, self.params)
        
        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var], 
                                       outputs=[self.prediction, self.loss],
                                       updates=updates)
        
        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var],
                                  outputs=[self.prediction, self.loss, self.inp_c, self.q_q, last_mem])
        
        
        if self.mode == 'train':
            print "==> computing gradients (for debugging)"
            gradient = T.grad(self.loss, self.params)
            self.get_gradient_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var], outputs=gradient)
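
The new_attention_step body is not shown in these snippets, but the (dim, 7 * dim + 2) shape of W_1 matches the attention feature vector of the DMN formulation: the fact c, memory m and question q, their element-wise products, absolute differences, and two bilinear terms through W_b, fed into a two-layer gate. A small NumPy sketch under that assumption (the feature order and names here are illustrative, not the snippet's own code):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def attention_gate(c, m, q, W_b, W_1, b_1, W_2, b_2):
    # scalar attention gate for one fact c given memory m and question q;
    # the feature vector has 7*dim + 2 entries, matching W_1's shape above
    z = np.concatenate([c, m, q,
                        c * q, c * m,
                        np.abs(c - q), np.abs(c - m),
                        [c.dot(W_b).dot(q)], [c.dot(W_b).dot(m)]])
    return sigmoid(W_2.dot(np.tanh(W_1.dot(z) + b_1)) + b_2)  # shape (1,)

dim = 5
rng = np.random.RandomState(1)
g = attention_gate(rng.randn(dim), rng.randn(dim), rng.randn(dim),
                   rng.randn(dim, dim),
                   rng.randn(dim, 7 * dim + 2), np.zeros(dim),
                   rng.randn(1, dim), np.zeros(1))
assert g.shape == (1,)
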
    def __init__(self, babi_train_raw, babi_test_raw, word2vec, word_vector_size, dim,
                mode, answer_module, input_mask_mode, memory_hops, batch_size, l2,
                normalize_attention, batch_norm, dropout, **kwargs):

        print "==> not used params in DMN class:", kwargs.keys()

        self.vocab = {}
        self.ivocab = {}

        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.batch_size = batch_size
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.batch_norm = batch_norm
        self.dropout = dropout

        self.train_input, self.train_q, self.train_answer, self.train_fact_count, self.train_input_mask = self._process_input(babi_train_raw)
        self.test_input, self.test_q, self.test_answer, self.test_fact_count, self.test_input_mask = self._process_input(babi_test_raw)
        self.vocab_size = len(self.vocab)

        self.input_var = T.tensor3('input_var') # (batch_size, seq_len, glove_dim)
        self.q_var = T.tensor3('question_var') # as self.input_var
        self.answer_var = T.ivector('answer_var') # answer of example in minibatch
        self.fact_count_var = T.ivector('fact_count_var') # number of facts in the example of minibatch
        self.input_mask_var = T.imatrix('input_mask_var') # (batch_size, indices)

        print "==> building input module"
        self.W_inp_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_inp_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_inp_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        input_var_shuffled = self.input_var.dimshuffle(1, 2, 0)
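        # scan iterates over the leading axis, so the dimshuffle above reorders
        # (batch_size, seq_len, glove_dim) into (seq_len, glove_dim, batch_size)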
        inp_dummy = theano.shared(np.zeros((self.dim, self.batch_size), dtype=floatX))
        inp_c_history, _ = theano.scan(fn=self.input_gru_step,
                            sequences=input_var_shuffled,
                            outputs_info=T.zeros_like(inp_dummy))

        inp_c_history_shuffled = inp_c_history.dimshuffle(2, 0, 1)

        inp_c_list = []
        inp_c_mask_list = []
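        # pad each example's selected fact vectors with zeros to a common length and
        # build a binary mask marking which positions hold real facts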
        for batch_index in range(self.batch_size):
            taken = inp_c_history_shuffled[batch_index].take(self.input_mask_var[batch_index, :self.fact_count_var[batch_index]], axis=0)
            inp_c_list.append(T.concatenate([taken, T.zeros((self.input_mask_var.shape[1] - taken.shape[0], self.dim), floatX)]))
            inp_c_mask_list.append(T.concatenate([T.ones((taken.shape[0],), np.int32), T.zeros((self.input_mask_var.shape[1] - taken.shape[0],), np.int32)]))

        self.inp_c = T.stack(inp_c_list).dimshuffle(1, 2, 0)
        inp_c_mask = T.stack(inp_c_mask_list).dimshuffle(1, 0)

        q_var_shuffled = self.q_var.dimshuffle(1, 2, 0)
        q_dummy = theano.shared(np.zeros((self.dim, self.batch_size), dtype=floatX))
        q_q_history, _ = theano.scan(fn=self.input_gru_step,
                            sequences=q_var_shuffled,
                            outputs_info=T.zeros_like(q_dummy))
        self.q_q = q_q_history[-1]


        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 0))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1,))


        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(self.GRU_update(memory[iter - 1], current_episode,
                                          self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
                                          self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                                          self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid))

        last_mem_raw = memory[-1].dimshuffle((1, 0))

        net = layers.InputLayer(shape=(self.batch_size, self.dim), input_var=last_mem_raw)
        if self.batch_norm:
            net = layers.BatchNormLayer(incoming=net)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net).dimshuffle((1, 0))


        print "==> building answer module"
        self.W_a = nn_utils.normal_param(std=0.1, shape=(self.vocab_size, self.dim))

        if self.answer_module == 'feedforward':
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))

        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))

            self.W_ans_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

            self.W_ans_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
            self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

            def answer_step(prev_a, prev_y):
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                  self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                                  self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                                  self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid)

                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]

            # TODO: add conditional ending
            dummy = theano.shared(np.zeros((self.vocab_size, self.batch_size), dtype=floatX))
            results, updates = theano.scan(fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(dummy)], #(last_mem, y)
                n_steps=1)
            self.prediction = results[1][-1]

        else:
            raise Exception("invalid answer_module")

        self.prediction = self.prediction.dimshuffle(1, 0)

        self.params = [self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
                  self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
                  self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
                  self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
                  self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                  self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, #self.W_b
                  self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]

        if self.answer_module == 'recurrent':
            self.params = self.params + [self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                              self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                              self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid]


        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction, self.answer_var).mean()

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        updates = lasagne.updates.adadelta(self.loss, self.params)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.001)

        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var,
                                                    self.fact_count_var, self.input_mask_var],
                                            outputs=[self.prediction, self.loss],
                                            updates=updates)

        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var,
                                               self.fact_count_var, self.input_mask_var],
                                       outputs=[self.prediction, self.loss])
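
For reference, the batched loss above is an ordinary categorical cross-entropy over row-wise softmax outputs, with one answer index per example. A self-contained NumPy sketch of the same arithmetic, assuming the prediction matrix is already (batch_size, vocab_size) as it is after the final dimshuffle:

import numpy as np

def softmax_rows(x):
    # row-wise softmax, numerically stabilised
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def categorical_crossentropy_mean(pred, labels):
    # mean negative log-likelihood of the true class, as in
    # T.nnet.categorical_crossentropy(self.prediction, self.answer_var).mean()
    return -np.log(pred[np.arange(len(labels)), labels]).mean()

rng = np.random.RandomState(2)
logits = rng.randn(3, 10)            # (batch_size, vocab_size)
pred = softmax_rows(logits)
labels = np.array([4, 0, 7])         # answer_var: one class index per example
loss = categorical_crossentropy_mean(pred, labels)
assert loss > 0
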
Exemple #17
    def __init__(self, train_raw, dev_raw, test_raw, word2vec,
                 word_vector_size, dim, mode, input_mask_mode, memory_hops, l2,
                 normalize_attention, dropout, **kwargs):
        print "generate one-word answer for mctest"
        print "==> not used params in DMN class:", kwargs.keys()
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.vocab_size = len(word2vec)

        self.dim = dim  # hidden state size
        self.mode = mode
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.dropout = dropout

        self.train_input, self.train_q, self.train_answer, self.train_input_mask = self._process_input(
            train_raw)
        self.dev_input, self.dev_q, self.dev_answer, self.dev_input_mask = self._process_input(
            dev_raw)
        self.test_input, self.test_q, self.test_answer, self.test_input_mask = self._process_input(
            test_raw)

        self.input_var = T.matrix('input_var')
        self.q_var = T.matrix('question_var')
        self.answer_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')
        self.attentions = []

        print "==> building input module"
        self.W_inp_res_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_upd_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_hid_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        inp_c_history, _ = theano.scan(fn=self.input_gru_step,
                                       sequences=self.input_var,
                                       outputs_info=T.zeros_like(
                                           self.b_inp_hid))

        self.inp_c = inp_c_history.take(self.input_mask_var, axis=0)

        self.q_q, _ = theano.scan(fn=self.input_gru_step,
                                  sequences=self.q_var,
                                  outputs_info=T.zeros_like(self.b_inp_hid))

        self.q_q = self.q_q[-1]

        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1,
                                         shape=(self.dim, 7 * self.dim + 2))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim, ))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1, ))

        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(
                self.GRU_update(memory[iter - 1], current_episode,
                                self.W_mem_res_in, self.W_mem_res_hid,
                                self.b_mem_res, self.W_mem_upd_in,
                                self.W_mem_upd_hid, self.b_mem_upd,
                                self.W_mem_hid_in, self.W_mem_hid_hid,
                                self.b_mem_hid))

        last_mem_raw = memory[-1].dimshuffle(('x', 0))

        net = layers.InputLayer(shape=(1, self.dim), input_var=last_mem_raw)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net)[0]
        self.attentions = T.stack(self.attentions)

        print "==> building answer module"
        self.W_a = nn_utils.normal_param(std=0.1,
                                         shape=(self.vocab_size, self.dim))

        self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))

        print "==> collecting all parameters"
        self.params = [
            self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
            self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
            self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
            self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
            self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
            self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, self.W_b,
            self.W_1, self.W_2, self.b_1, self.b_2, self.W_a
        ]

        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(
            self.prediction.dimshuffle('x', 0), T.stack([self.answer_var]))[0]

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        updates = lasagne.updates.adam(self.loss, self.params)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.0003)

        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.answer_var,
                    self.input_mask_var
                ],
                allow_input_downcast=True,
                outputs=[self.prediction, self.loss],
                updates=updates)

        print "==> compiling test_fn"
        self.test_fn = theano.function(
            inputs=[
                self.input_var, self.q_var, self.answer_var,
                self.input_mask_var
            ],
            allow_input_downcast=True,
            outputs=[self.prediction, self.loss, self.attentions])
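
nn_utils.l2_reg is not shown in these snippets; it presumably returns the sum of squared entries over all parameter tensors, which the self.l2 coefficient then scales. A minimal NumPy sketch under that assumption:

import numpy as np

def l2_reg(params):
    # assumed behaviour of nn_utils.l2_reg: sum of squares over every parameter array
    return sum((p ** 2).sum() for p in params)

rng = np.random.RandomState(3)
params = [rng.randn(4, 4), rng.randn(4)]
loss_l2 = 0.001 * l2_reg(params)     # plays the role of self.l2 * nn_utils.l2_reg(self.params)
assert loss_l2 >= 0
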
Exemple #18
 def get_output(self, X):
     """Perform the forward step transformation."""
     return softmax(X)
    def __init__(self, data_dir, word2vec, word_vector_size, dim, mode,
                 answer_module, memory_hops, batch_size, l2,
                 normalize_attention, batch_norm, dropout, **kwargs):

        print "==> not used params in DMN class:", kwargs.keys()

        self.data_dir = data_dir

        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.memory_hops = memory_hops
        self.batch_size = batch_size
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.batch_norm = batch_norm
        self.dropout = dropout

        self.vocab, self.ivocab = self._load_vocab(self.data_dir)

        self.train_story = None
        self.test_story = None
        self.train_dict_story, self.train_features, self.train_fns_dict, self.train_num_imgs = self._process_input_sind(
            self.data_dir, 'train')
        self.test_dict_story, self.test_features, self.test_fns_dict, self.test_num_imgs = self._process_input_sind(
            self.data_dir, 'val')

        self.train_story = self.train_dict_story.keys()
        self.test_story = self.test_dict_story.keys()
        self.vocab_size = len(self.vocab)

        self.input_var = T.tensor3(
            'input_var')  # (batch_size, seq_len, cnn_dim)
        self.q_var = T.matrix('q_var')  # Now it's a batch x image_size matrix.
        self.answer_var = T.imatrix(
            'answer_var')  # answer of example in minibatch
        self.answer_mask = T.matrix('answer_mask')
        self.answer_inp_var = T.tensor3(
            'answer_inp_var')  # answer of example in minibatch

        print "==> building input module"
        # It's very simple now: the input module just needs to map from cnn_dim to dim.
        logging.info('self.cnn_dim = %d', self.cnn_dim)
        self.W_inp_emb_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim,
                                                         self.cnn_dim))
        self.b_inp_emb_in = nn_utils.constant_param(value=0.0,
                                                    shape=(self.dim, ))

        #inp_c_hist = T.dot(self.W_inp_emb_in, self.input_var) + self.b_inp_emb_in
        inp_var_shuffled = self.input_var.dimshuffle(1, 2, 0)
        print inp_var_shuffled.shape.eval(
            {self.input_var: np.random.rand(10, 4, 4096).astype('float32')})

        def _dot(x, W, b):
            return T.dot(W, x) + b.dimshuffle(0, 'x')

        inp_c_hist, _ = theano.scan(
            fn=_dot,
            sequences=inp_var_shuffled,
            non_sequences=[self.W_inp_emb_in, self.b_inp_emb_in])
        #inp_c_hist,_ = theano.scan(fn = _dot, sequences=self.input_var, non_sequences = [self.W_inp_emb_in, self.b_inp_emb_in])

        #self.inp_c = inp_c_hist.dimshuffle(2,0,1) # b x len x fea
        self.inp_c = inp_c_hist

        print "==> building question module"
        # Now, share the parameter with the input module.
        q_var_shuffled = self.q_var.dimshuffle(1, 0)
        q_hist = T.dot(self.W_inp_emb_in,
                       q_var_shuffled) + self.b_inp_emb_in.dimshuffle(0, 'x')

        self.q_q = q_hist.dimshuffle(0, 1)  # dim x batch

        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1,
                                         shape=(self.dim, 7 * self.dim + 0))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim, ))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1, ))

        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            #m = printing.Print('mem')(memory[iter-1])
            current_episode = self.new_episode(memory[iter - 1])
            #current_episode = self.new_episode(m)
            #current_episode = printing.Print('current_episode')(current_episode)
            memory.append(
                self.GRU_update(memory[iter - 1], current_episode,
                                self.W_mem_res_in, self.W_mem_res_hid,
                                self.b_mem_res, self.W_mem_upd_in,
                                self.W_mem_upd_hid, self.b_mem_upd,
                                self.W_mem_hid_in, self.W_mem_hid_hid,
                                self.b_mem_hid))

        last_mem_raw = memory[-1].dimshuffle((1, 0))

        net = layers.InputLayer(shape=(self.batch_size, self.dim),
                                input_var=last_mem_raw)

        if self.batch_norm:
            net = layers.BatchNormLayer(incoming=net)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net).dimshuffle((1, 0))

        logging.info('last_mem size')
        print last_mem.shape.eval({
            self.input_var:
            np.random.rand(10, 4, 4096).astype('float32'),
            self.q_var:
            np.random.rand(10, 4096).astype('float32')
        })

        print "==> building answer module"

        answer_inp_var_shuffled = self.answer_inp_var.dimshuffle(1, 2, 0)
        # Because of the additional #start token, the vocab dimension of the answer-module parameters below gets +1.
        dummy = theano.shared(
            np.zeros((self.vocab_size + 1, self.batch_size), dtype=floatX))

        self.W_a = nn_utils.normal_param(std=0.1,
                                         shape=(self.vocab_size + 1, self.dim))

        self.W_ans_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim +
                                                         self.vocab_size + 1))
        self.W_ans_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_ans_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim +
                                                         self.vocab_size + 1))
        self.W_ans_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_ans_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim +
                                                         self.vocab_size + 1))
        self.W_ans_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        logging.info('answer_inp_var_shuffled size')

        print answer_inp_var_shuffled.shape.eval({
            self.answer_inp_var:
            np.random.rand(10, 18, 8001).astype('float32')
        })

        #last_mem = printing.Print('prob_sm')(last_mem)
        results, _ = theano.scan(fn=self.answer_gru_step,
                                 sequences=answer_inp_var_shuffled,
                                 outputs_info=[last_mem])
        # Assume there is a start token
        print results.shape.eval({
            self.input_var:
            np.random.rand(10, 4, 4096).astype('float32'),
            self.q_var:
            np.random.rand(10, 4096).astype('float32'),
            self.answer_inp_var:
            np.random.rand(10, 18, 8001).astype('float32')
        })
        results = results[0:-1, :, :]  # get rid of the last token.
        print results.shape.eval({
            self.input_var:
            np.random.rand(10, 4, 4096).astype('float32'),
            self.q_var:
            np.random.rand(10, 4096).astype('float32'),
            self.answer_inp_var:
            np.random.rand(10, 18, 8001).astype('float32')
        })

        # Now map the hidden states to vocabulary scores (softmax is applied below).

        prob, _ = theano.scan(fn=lambda x, w: T.dot(w, x),
                              sequences=results,
                              non_sequences=self.W_a)

        prob_shuffled = prob.dimshuffle(2, 0, 1)  # b * len * vocab

        logging.info("prob shape.")
        print prob.shape.eval({
            self.input_var:
            np.random.rand(10, 4, 4096).astype('float32'),
            self.q_var:
            np.random.rand(10, 4096).astype('float32'),
            self.answer_inp_var:
            np.random.rand(10, 18, 8001).astype('float32')
        })

        n = prob_shuffled.shape[0] * prob_shuffled.shape[1]
        prob_rhp = T.reshape(prob_shuffled, (n, prob_shuffled.shape[2]))
        prob_sm = nn_utils.softmax(prob_rhp)
        self.prediction = prob_sm

        mask = T.reshape(self.answer_mask, (n, ))
        lbl = T.reshape(self.answer_var, (n, ))

        self.params = [
            self.W_inp_emb_in,
            self.b_inp_emb_in,
            self.W_mem_res_in,
            self.W_mem_res_hid,
            self.b_mem_res,
            self.W_mem_upd_in,
            self.W_mem_upd_hid,
            self.b_mem_upd,
            self.W_mem_hid_in,
            self.W_mem_hid_hid,
            self.b_mem_hid,  #self.W_b
            self.W_1,
            self.W_2,
            self.b_1,
            self.b_2,
            self.W_a
        ]

        self.params = self.params + [
            self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
            self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
            self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid
        ]

        print "==> building loss layer and computing updates"
        loss_vec = T.nnet.categorical_crossentropy(prob_sm, lbl)
        self.loss_ce = (mask * loss_vec).sum() / mask.sum()

        #self.loss_ce = T.nnet.categorical_crossentropy(results_rhp, lbl)

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        updates = lasagne.updates.adadelta(self.loss, self.params)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.001)

        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.answer_var,
                    self.answer_mask, self.answer_inp_var
                ],
                outputs=[self.prediction, self.loss],
                updates=updates)

        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[
            self.input_var, self.q_var, self.answer_var, self.answer_mask,
            self.answer_inp_var
        ],
                                       outputs=[self.prediction, self.loss])
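
The captioning-style loss above flattens (batch, len) into a single axis and averages the per-token cross-entropy only over unmasked positions. A NumPy sketch of that masking arithmetic with toy shapes (not the snippet's real data):

import numpy as np

def masked_sequence_loss(prob_sm, labels, mask):
    # prob_sm: (batch*len, vocab) row-stochastic; labels, mask: (batch*len,)
    # mirrors (mask * loss_vec).sum() / mask.sum() above
    loss_vec = -np.log(prob_sm[np.arange(len(labels)), labels])
    return (mask * loss_vec).sum() / mask.sum()

rng = np.random.RandomState(4)
batch, length, vocab = 2, 5, 11
logits = rng.randn(batch * length, vocab)
prob = np.exp(logits - logits.max(axis=1, keepdims=True))
prob /= prob.sum(axis=1, keepdims=True)
labels = rng.randint(0, vocab, size=batch * length)
mask = np.ones(batch * length)
mask[-2:] = 0.0                      # pretend the last two tokens are padding
assert masked_sequence_loss(prob, labels, mask) > 0
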
Exemple #20
    def __init__(self, babi_train_raw, babi_test_raw, word2vec,
                 word_vector_size, sent_vector_size, dim, mode, answer_module,
                 input_mask_mode, memory_hops, l2, normalize_attention,
                 batch_norm, dropout, dropout_in, **kwargs):

        print "==> not used params in DMN class:", kwargs.keys()
        self.vocab = {None: 0}
        self.ivocab = {0: None}

        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.sent_vector_size = sent_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.batch_norm = batch_norm
        self.dropout = dropout
        self.dropout_in = dropout_in

        self.max_inp_sent_len = 0
        self.max_q_len = 0
        """
        #To Use All Vocab
        self.vocab = {None: 0, 'jason': 134.0, 'office': 14.0, 'yellow': 78.0, 'bedroom': 24.0, 'go': 108.0, 'yes': 15.0, 'antoine': 138.0, 'milk': 139.0, 'before': 46.0, 'grabbed': 128.0, 'fit': 100.0, 'how': 105.0, 'swan': 73.0, 'than': 96.0, 'to': 13.0, 'does': 99.0, 's,e': 110.0, 'east': 102.0, 'rectangle': 82.0, 'gave': 149.0, 'then': 39.0, 'evening': 48.0, 'triangle': 79.0, 'garden': 37.0, 'get': 131.0, 'football,apple,milk': 179.0, 'they': 41.0, 'not': 178.0, 'bigger': 95.0, 'gray': 77.0, 'school': 6.0, 'apple': 142.0, 'did': 127.0, 'morning': 44.0, 'discarded': 146.0, 'julius': 72.0, 'she': 29.0, 'went': 11.0, 'where': 30.0, 'jeff': 152.0, 'square': 84.0, 'who': 153.0, 'tired': 124.0, 'there': 130.0, 'back': 12.0, 'lion': 70.0, 'are': 50.0, 'picked': 143.0, 'e,e': 119.0, 'pajamas': 129.0, 'Mary': 157.0, 'blue': 83.0, 'what': 63.0, 'container': 98.0, 'rhino': 76.0, 'daniel': 31.0, 'bernhard': 67.0, 'milk,football': 172.0, 'above': 80.0, 'got': 136.0, 'emily': 60.0, 'red': 88.0, 'either': 3.0, 'sheep': 58.0, 'football': 137.0, 'jessica': 61.0, 'do': 106.0, 'Bill': 155.0, 'football,apple': 168.0, 'fred': 1.0, 'winona': 59.0, 'objects': 161.0, 'put': 147.0, 'kitchen': 17.0, 'box': 90.0, 'received': 154.0, 'journeyed': 25.0, 'of': 52.0, 'wolf': 62.0, 'afternoon': 47.0, 'or': 7.0, 'south': 112.0, 's,w': 114.0, 'afterwards': 32.0, 'sumit': 123.0, 'color': 75.0, 'julie': 23.0, 'one': 163.0, 'down': 148.0, 'nothing': 167.0, 'n,n': 113.0, 'right': 86.0, 's,s': 116.0, 'gertrude': 54.0, 'bathroom': 26.0, 'from': 109.0, 'west': 104.0, 'chocolates': 91.0, 'two': 165.0, 'frog': 66.0, '.': 9.0, 'cats': 57.0, 'apple,milk,football': 175.0, 'passed': 158.0, 'apple,football,milk': 176.0, 'white': 71.0, 'john': 35.0, 'was': 45.0, 'mary': 10.0, 'apple,football': 170.0, 'north': 103.0, 'n,w': 111.0, 'that': 28.0, 'park': 8.0, 'took': 141.0, 'chocolate': 101.0, 'carrying': 162.0, 'n,e': 120.0, 'mice': 49.0, 'travelled': 22.0, 'he': 33.0, 'none': 164.0, 'bored': 133.0, 'e,n': 117.0, None: 0, 'Jeff': 159.0, 'this': 43.0, 'inside': 93.0, 'bill': 16.0, 'up': 144.0, 'cat': 64.0, 'will': 125.0, 'below': 87.0, 'greg': 74.0, 'three': 166.0, 'suitcase': 97.0, 'following': 36.0, 'e,s': 115.0, 'and': 40.0, 'thirsty': 135.0, 'cinema': 19.0, 'is': 2.0, 'moved': 18.0, 'yann': 132.0, 'sphere': 89.0, 'dropped': 145.0, 'in': 4.0, 'mouse': 56.0, 'football,milk': 171.0, 'pink': 81.0, 'afraid': 51.0, 'no': 20.0, 'Fred': 156.0, 'w,s': 121.0, 'handed': 151.0, 'w,w': 118.0, 'brian': 69.0, 'chest': 94.0, 'w,n': 122.0, 'you': 107.0, 'many': 160.0, 'lily': 65.0, 'hallway': 34.0, 'why': 126.0, 'after': 27.0, 'yesterday': 42.0, 'sandra': 38.0, 'fits': 92.0, 'milk,football,apple': 173.0, 'the': 5.0, 'milk,apple': 169.0, 'a': 55.0, 'give': 150.0, 'longer': 177.0, 'maybe': 21.0, 'hungry': 140.0, 'apple,milk': 174.0, 'green': 68.0, 'wolves': 53.0, 'left': 85.0}
        self.ivocab = {0: None, 1: 'fred', 2: 'is', 3: 'either', 4: 'in', 5: 'the', 6: 'school', 7: 'or', 8: 'park', 9: '.', 10: 'mary', 11: 'went', 12: 'back', 13: 'to', 14: 'office', 15: 'yes', 16: 'bill', 17: 'kitchen', 18: 'moved', 19: 'cinema', 20: 'no', 21: 'maybe', 22: 'travelled', 23: 'julie', 24: 'bedroom', 25: 'journeyed', 26: 'bathroom', 27: 'after', 28: 'that', 29: 'she', 30: 'where', 31: 'daniel', 32: 'afterwards', 33: 'he', 34: 'hallway', 35: 'john', 36: 'following', 37: 'garden', 38: 'sandra', 39: 'then', 40: 'and', 41: 'they', 42: 'yesterday', 43: 'this', 44: 'morning', 45: 'was', 46: 'before', 47: 'afternoon', 48: 'evening', 49: 'mice', 50: 'are', 51: 'afraid', 52: 'of', 53: 'wolves', 54: 'gertrude', 55: 'a', 56: 'mouse', 57: 'cats', 58: 'sheep', 59: 'winona', 60: 'emily', 61: 'jessica', 62: 'wolf', 63: 'what', 64: 'cat', 65: 'lily', 66: 'frog', 67: 'bernhard', 68: 'green', 69: 'brian', 70: 'lion', 71: 'white', 72: 'julius', 73: 'swan', 74: 'greg', 75: 'color', 76: 'rhino', 77: 'gray', 78: 'yellow', 79: 'triangle', 80: 'above', 81: 'pink', 82: 'rectangle', 83: 'blue', 84: 'square', 85: 'left', 86: 'right', 87: 'below', 88: 'red', 89: 'sphere', 90: 'box', 91: 'chocolates', 92: 'fits', 93: 'inside', 94: 'chest', 95: 'bigger', 96: 'than', 97: 'suitcase', 98: 'container', 99: 'does', 100: 'fit', 101: 'chocolate', 102: 'east', 103: 'north', 104: 'west', 105: 'how', 106: 'do', 107: 'you', 108: 'go', 109: 'from', 110: 's,e', 111: 'n,w', 112: 'south', 113: 'n,n', 114: 's,w', 115: 'e,s', 116: 's,s', 117: 'e,n', 118: 'w,w', 119: 'e,e', 120: 'n,e', 121: 'w,s', 122: 'w,n', 123: 'sumit', 124: 'tired', 125: 'will', 126: 'why', 127: 'did', 128: 'grabbed', 129: 'pajamas', 130: 'there', 131: 'get', 132: 'yann', 133: 'bored', 134: 'jason', 135: 'thirsty', 136: 'got', 137: 'football', 138: 'antoine', 139: 'milk', 140: 'hungry', 141: 'took', 142: 'apple', 143: 'picked', 144: 'up', 145: 'dropped', 146: 'discarded', 147: 'put', 148: 'down', 149: 'gave', 150: 'give', 151: 'handed', 152: 'jeff', 153: 'who', 154: 'received', 155: 'Bill', 156: 'Fred', 157: 'Mary', 158: 'passed', 159: 'Jeff', 160: 'many', 161: 'objects', 162: 'carrying', 163: 'one', 164: 'none', 165: 'two', 166: 'three', 167: 'nothing', 168: 'football,apple', 169: 'milk,apple', 170: 'apple,football', 171: 'football,milk', 172: 'milk,football', 173: 'milk,football,apple', 174: 'apple,milk', 175: 'apple,milk,football', 176: 'apple,football,milk', 177: 'longer', 178: 'not', 179: 'football,apple,milk'}
        #self.vocab = {'jason': 134.0, 'office': 14.0, 'yellow': 78.0, 'bedroom': 24.0, 'go': 108.0, 'yes': 15.0, 'antoine': 138.0, 'milk': 139.0, 'before': 46.0, 'grabbed': 128.0, 'fit': 100.0, 'how': 105.0, 'swan': 73.0, 'than': 96.0, 'to': 13.0, 'does': 99.0, 's,e': 110.0, 'east': 102.0, 'rectangle': 82.0, 'gave': 149.0, 'then': 39.0, 'evening': 48.0, 'triangle': 79.0, 'garden': 37.0, 'get': 131.0, 'football,apple,milk': 179.0, 'they': 41.0, 'not': 178.0, 'bigger': 95.0, 'gray': 77.0, 'school': 6.0, 'apple': 142.0, 'did': 127.0, 'morning': 44.0, 'discarded': 146.0, 'julius': 72.0, 'she': 29.0, 'went': 11.0, 'where': 30.0, 'jeff': 152.0, 'square': 84.0, 'who': 153.0, 'tired': 124.0, 'there': 130.0, 'back': 12.0, 'lion': 70.0, 'are': 50.0, 'picked': 143.0, 'e,e': 119.0, 'pajamas': 129.0, 'Mary': 157.0, 'blue': 83.0, 'what': 63.0, 'container': 98.0, 'rhino': 76.0, 'daniel': 31.0, 'bernhard': 67.0, 'milk,football': 172.0, 'above': 80.0, 'got': 136.0, 'emily': 60.0, 'red': 88.0, 'either': 3.0, 'sheep': 58.0, 'football': 137.0, 'jessica': 61.0, 'do': 106.0, 'Bill': 155.0, 'football,apple': 168.0, 'fred': 1.0, 'winona': 59.0, 'objects': 161.0, 'put': 147.0, 'kitchen': 17.0, 'box': 90.0, 'received': 154.0, 'journeyed': 25.0, 'of': 52.0, 'wolf': 62.0, 'afternoon': 47.0, 'or': 7.0, 'south': 112.0, 's,w': 114.0, 'afterwards': 32.0, 'sumit': 123.0, 'color': 75.0, 'julie': 23.0, 'one': 163.0, 'down': 148.0, 'nothing': 167.0, 'n,n': 113.0, 'right': 86.0, 's,s': 116.0, 'gertrude': 54.0, 'bathroom': 26.0, 'from': 109.0, 'west': 104.0, 'chocolates': 91.0, 'two': 165.0, 'frog': 66.0, '.': 9.0, 'cats': 57.0, 'apple,milk,football': 175.0, 'passed': 158.0, 'apple,football,milk': 176.0, 'white': 71.0, 'john': 35.0, 'was': 45.0, 'mary': 10.0, 'apple,football': 170.0, 'north': 103.0, 'n,w': 111.0, 'that': 28.0, 'park': 8.0, 'took': 141.0, 'chocolate': 101.0, 'carrying': 162.0, 'n,e': 120.0, 'mice': 49.0, 'travelled': 22.0, 'he': 33.0, 'none': 164.0, 'bored': 133.0, 'e,n': 117.0, None: 0, 'Jeff': 159.0, 'this': 43.0, 'inside': 93.0, 'bill': 16.0, 'up': 144.0, 'cat': 64.0, 'will': 125.0, 'below': 87.0, 'greg': 74.0, 'three': 166.0, 'suitcase': 97.0, 'following': 36.0, 'e,s': 115.0, 'and': 40.0, 'thirsty': 135.0, 'cinema': 19.0, 'is': 2.0, 'moved': 18.0, 'yann': 132.0, 'sphere': 89.0, 'dropped': 145.0, 'in': 4.0, 'mouse': 56.0, 'football,milk': 171.0, 'pink': 81.0, 'afraid': 51.0, 'no': 20.0, 'Fred': 156.0, 'w,s': 121.0, 'handed': 151.0, 'w,w': 118.0, 'brian': 69.0, 'chest': 94.0, 'w,n': 122.0, 'you': 107.0, 'many': 160.0, 'lily': 65.0, 'hallway': 34.0, 'why': 126.0, 'after': 27.0, 'yesterday': 42.0, 'sandra': 38.0, 'fits': 92.0, 'milk,football,apple': 173.0, 'the': 5.0, 'milk,apple': 169.0, 'a': 55.0, 'give': 150.0, 'longer': 177.0, 'maybe': 21.0, 'hungry': 140.0, 'apple,milk': 174.0, 'green': 68.0, 'wolves': 53.0, 'left': 85.0}
        #self.ivocab = {1: 'fred', 2: 'is', 3: 'either', 4: 'in', 5: 'the', 6: 'school', 7: 'or', 8: 'park', 9: '.', 10: 'mary', 11: 'went', 12: 'back', 13: 'to', 14: 'office', 15: 'yes', 16: 'bill', 17: 'kitchen', 18: 'moved', 19: 'cinema', 20: 'no', 21: 'maybe', 22: 'travelled', 23: 'julie', 24: 'bedroom', 25: 'journeyed', 26: 'bathroom', 27: 'after', 28: 'that', 29: 'she', 30: 'where', 31: 'daniel', 32: 'afterwards', 33: 'he', 34: 'hallway', 35: 'john', 36: 'following', 37: 'garden', 38: 'sandra', 39: 'then', 40: 'and', 41: 'they', 42: 'yesterday', 43: 'this', 44: 'morning', 45: 'was', 46: 'before', 47: 'afternoon', 48: 'evening', 49: 'mice', 50: 'are', 51: 'afraid', 52: 'of', 53: 'wolves', 54: 'gertrude', 55: 'a', 56: 'mouse', 57: 'cats', 58: 'sheep', 59: 'winona', 60: 'emily', 61: 'jessica', 62: 'wolf', 63: 'what', 64: 'cat', 65: 'lily', 66: 'frog', 67: 'bernhard', 68: 'green', 69: 'brian', 70: 'lion', 71: 'white', 72: 'julius', 73: 'swan', 74: 'greg', 75: 'color', 76: 'rhino', 77: 'gray', 78: 'yellow', 79: 'triangle', 80: 'above', 81: 'pink', 82: 'rectangle', 83: 'blue', 84: 'square', 85: 'left', 86: 'right', 87: 'below', 88: 'red', 89: 'sphere', 90: 'box', 91: 'chocolates', 92: 'fits', 93: 'inside', 94: 'chest', 95: 'bigger', 96: 'than', 97: 'suitcase', 98: 'container', 99: 'does', 100: 'fit', 101: 'chocolate', 102: 'east', 103: 'north', 104: 'west', 105: 'how', 106: 'do', 107: 'you', 108: 'go', 109: 'from', 110: 's,e', 111: 'n,w', 112: 'south', 113: 'n,n', 114: 's,w', 115: 'e,s', 116: 's,s', 117: 'e,n', 118: 'w,w', 119: 'e,e', 120: 'n,e', 121: 'w,s', 122: 'w,n', 123: 'sumit', 124: 'tired', 125: 'will', 126: 'why', 127: 'did', 128: 'grabbed', 129: 'pajamas', 130: 'there', 131: 'get', 132: 'yann', 133: 'bored', 134: 'jason', 135: 'thirsty', 136: 'got', 137: 'football', 138: 'antoine', 139: 'milk', 140: 'hungry', 141: 'took', 142: 'apple', 143: 'picked', 144: 'up', 145: 'dropped', 146: 'discarded', 147: 'put', 148: 'down', 149: 'gave', 150: 'give', 151: 'handed', 152: 'jeff', 153: 'who', 154: 'received', 155: 'Bill', 156: 'Fred', 157: 'Mary', 158: 'passed', 159: 'Jeff', 160: 'many', 161: 'objects', 162: 'carrying', 163: 'one', 164: 'none', 165: 'two', 166: 'three', 167: 'nothing', 168: 'football,apple', 169: 'milk,apple', 170: 'apple,football', 171: 'football,milk', 172: 'milk,football', 173: 'milk,football,apple', 174: 'apple,milk', 175: 'apple,milk,football', 176: 'apple,football,milk', 177: 'longer', 178: 'not', 179: 'football,apple,milk'}
        #"""

        self.train_input, self.train_q, self.train_answer, self.train_input_mask = self._process_input(
            babi_train_raw)
        self.test_input, self.test_q, self.test_answer, self.test_input_mask = self._process_input(
            babi_test_raw)
        self.vocab_size = len(self.vocab)

        self.input_var = T.imatrix('input_var')
        self.q_var = T.ivector('question_var')
        self.answer_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')

        self.attentions = []

        self.pe_matrix_in = self.pe_matrix(self.max_inp_sent_len)
        self.pe_matrix_q = self.pe_matrix(self.max_q_len)

        print "==> building input module"

        #positional encoder weights
        self.W_pe = nn_utils.normal_param(std=0.1,
                                          shape=(self.vocab_size, self.dim))

        #biGRU input fusion weights
        self.W_inp_res_in_fwd = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_res_hid_fwd = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
        self.b_inp_res_fwd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

        self.W_inp_upd_in_fwd = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_upd_hid_fwd = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
        self.b_inp_upd_fwd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

        self.W_inp_hid_in_fwd = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_hid_hid_fwd = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
        self.b_inp_hid_fwd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

        self.W_inp_res_in_bwd = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_res_hid_bwd = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
        self.b_inp_res_bwd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

        self.W_inp_upd_in_bwd = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_upd_hid_bwd = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
        self.b_inp_upd_bwd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

        self.W_inp_hid_in_bwd = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.sent_vector_size))
        self.W_inp_hid_hid_bwd = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
        self.b_inp_hid_bwd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

        #self.V_f = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        #self.V_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
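        # Note (editor): the *_fwd / *_bwd weights above are for two GRUs run over
        # the sequence of sentence vectors in opposite directions (the DMN+ "input
        # fusion layer"); input_module_full presumably combines them, e.g.
        # fact_i = fwd_i + bwd_i, so each fact carries context from the sentences
        # both before and after it.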

        self.inp_sent_reps, _ = theano.scan(fn=self.sum_pos_encodings_in,
                                            sequences=self.input_var)

        self.inp_sent_reps_stacked = T.stacklists(self.inp_sent_reps)
        #self.inp_c = self.input_module_full(self.inp_sent_reps_stacked)

        self.inp_c = self.input_module_full(self.inp_sent_reps)

        self.q_q = self.sum_pos_encodings_q(self.q_var)

        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.memory_hops,
                                                         self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.memory_hops,
                                                          self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0,
                                                 shape=(
                                                     self.memory_hops,
                                                     self.dim,
                                                 ))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.memory_hops,
                                                         self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.memory_hops,
                                                          self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0,
                                                 shape=(
                                                     self.memory_hops,
                                                     self.dim,
                                                 ))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.memory_hops,
                                                         self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.memory_hops,
                                                          self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0,
                                                 shape=(
                                                     self.memory_hops,
                                                     self.dim,
                                                 ))

        #self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        #self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 0))
        self.W_1 = nn_utils.normal_param(std=0.1,
                                         shape=(self.memory_hops, self.dim,
                                                4 * self.dim + 0))
        self.W_2 = nn_utils.normal_param(std=0.1,
                                         shape=(self.memory_hops, 1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0,
                                           shape=(
                                               self.memory_hops,
                                               self.dim,
                                           ))
        self.b_2 = nn_utils.constant_param(value=0.0,
                                           shape=(
                                               self.memory_hops,
                                               1,
                                           ))
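        # Note (editor): judging from W_1's shape (memory_hops, dim, 4 * dim), the
        # attention features for fact c, memory m and question q are presumably
        # z = [c * q, c * m, |c - q|, |c - m|] (four dim-sized blocks), and the
        # gate for each fact is a small MLP applied with per-hop (untied) weights:
        #     g = W_2[hop] . tanh(W_1[hop] . z + b_1[hop]) + b_2[hop]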

        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            self.mem_weight_num = int(iter - 1)
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(
                self.GRU_update(memory[iter - 1], current_episode,
                                self.W_mem_res_in[self.mem_weight_num],
                                self.W_mem_res_hid[self.mem_weight_num],
                                self.b_mem_res[self.mem_weight_num],
                                self.W_mem_upd_in[self.mem_weight_num],
                                self.W_mem_upd_hid[self.mem_weight_num],
                                self.b_mem_upd[self.mem_weight_num],
                                self.W_mem_hid_in[self.mem_weight_num],
                                self.W_mem_hid_hid[self.mem_weight_num],
                                self.b_mem_hid[self.mem_weight_num]))

        last_mem_raw = memory[-1].dimshuffle(('x', 0))

        net = layers.InputLayer(shape=(1, self.dim), input_var=last_mem_raw)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net)[0]

        print "==> building answer module"
        self.W_a = nn_utils.normal_param(std=0.1,
                                         shape=(self.vocab_size, self.dim))

        if self.answer_module == 'feedforward':
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))

        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_res_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_res = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            self.W_ans_upd_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_upd_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_upd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            self.W_ans_hid_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_hid_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_hid = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            def answer_step(prev_a, prev_y):
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                    self.W_ans_res_in, self.W_ans_res_hid,
                                    self.b_ans_res, self.W_ans_upd_in,
                                    self.W_ans_upd_hid, self.b_ans_upd,
                                    self.W_ans_hid_in, self.W_ans_hid_hid,
                                    self.b_ans_hid)

                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]

            # add conditional ending?
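            # Note (editor): each answer step feeds the previous softmax output,
            # concatenated with the question, back into a GRU whose state starts
            # from the final memory: a_t = GRU([y_{t-1}; q], a_{t-1}) and
            # y_t = softmax(W_a a_t). With n_steps=1 below only a single answer
            # word is produced.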
            dummy = theano.shared(np.zeros((self.vocab_size, ), dtype=floatX))

            results, updates = theano.scan(
                fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(dummy)],
                n_steps=1)
            self.prediction = results[1][-1]

        else:
            raise Exception("invalid answer_module")

        print "==> collecting all parameters"
        self.params = [
            self.W_pe,
            self.W_inp_res_in_fwd,
            self.W_inp_res_hid_fwd,
            self.b_inp_res_fwd,
            self.W_inp_upd_in_fwd,
            self.W_inp_upd_hid_fwd,
            self.b_inp_upd_fwd,
            self.W_inp_hid_in_fwd,
            self.W_inp_hid_hid_fwd,
            self.b_inp_hid_fwd,
            self.W_inp_res_in_bwd,
            self.W_inp_res_hid_bwd,
            self.b_inp_res_bwd,
            self.W_inp_upd_in_bwd,
            self.W_inp_upd_hid_bwd,
            self.b_inp_upd_bwd,
            self.W_inp_hid_in_bwd,
            self.W_inp_hid_hid_bwd,
            self.b_inp_hid_bwd,
            self.W_mem_res_in,
            self.W_mem_res_hid,
            self.b_mem_res,
            self.W_mem_upd_in,
            self.W_mem_upd_hid,
            self.b_mem_upd,
            self.W_mem_hid_in,
            self.W_mem_hid_hid,
            self.b_mem_hid,  #self.W_b
            self.W_1,
            self.W_2,
            self.b_1,
            self.b_2,
            self.W_a
        ]

        if self.answer_module == 'recurrent':
            self.params = self.params + [
                self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid
            ]

        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(
            self.prediction.dimshuffle('x', 0), T.stack([self.answer_var]))[0]

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        #updates = lasagne.updates.adadelta(self.loss, self.params)
        #updates = lasagne.updates.adam(self.loss, self.params)  # default hyperparameters
        updates = lasagne.updates.adam(self.loss,
                                       self.params,
                                       learning_rate=0.0001,
                                       beta1=0.5)  # beta1=0.5 as in the DCGAN paper
        #updates = lasagne.updates.adadelta(self.loss, self.params, learning_rate=0.0005)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.0003)

        self.attentions = T.stack(self.attentions)
        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.answer_var,
                    self.input_mask_var
                ],
                outputs=[self.prediction, self.loss, self.attentions],
                updates=updates,
                on_unused_input='warn',
                allow_input_downcast=True)

        print "==> compiling test_fn"
        self.test_fn = theano.function(
            inputs=[
                self.input_var, self.q_var, self.answer_var,
                self.input_mask_var
            ],
            outputs=[self.prediction, self.loss, self.attentions],
            on_unused_input='warn',
            allow_input_downcast=True)
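        # Note (editor): a minimal usage sketch with hypothetical variable names:
        # during training one would presumably call
        #     prediction, loss, attentions = self.train_fn(inp, q, ans, input_mask)
        # and self.test_fn with the same four arguments at test time; inputs are
        # cast automatically thanks to allow_input_downcast=True.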
    def __init__(self, babi_train_raw, babi_test_raw, word2vec, word_vector_size, dim, 
                mode, answer_module, input_mask_mode, memory_hops, batch_size, l2,
                normalize_attention):
        
        self.vocab = {}
        self.ivocab = {}
        
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.batch_size = batch_size
        self.l2 = l2
        self.normalize_attention = normalize_attention

        self.max_fact_count = 0
        
        self.train_input, self.train_q, self.train_answer, self.train_fact_count, self.train_input_mask = self._process_input(babi_train_raw)
        self.test_input, self.test_q, self.test_answer, self.test_fact_count, self.test_input_mask = self._process_input(babi_test_raw)
        self.vocab_size = len(self.vocab)
        
        self.input_var = T.tensor3('input_var') # (batch_size, seq_len, glove_dim)
        self.q_var = T.tensor3('question_var') # as self.input_var
        self.answer_var = T.ivector('answer_var') # answer of example in minibatch
        self.fact_count_var = T.ivector('fact_count_var') # number of facts in the example of minibatch
        self.input_mask_var = T.imatrix('input_mask_var') # (batch_size, indices) 
        
        print "==> building input module"
        self.W_inp_res_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.word_vector_size)), borrow=True)
        self.W_inp_res_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_inp_res = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_inp_upd_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.word_vector_size)), borrow=True)
        self.W_inp_upd_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_inp_upd = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_inp_hid_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.word_vector_size)), borrow=True)
        self.W_inp_hid_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_inp_hid = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
  
        input_var_shuffled = self.input_var.dimshuffle(1, 2, 0)
        inp_dummy = theano.shared(np.zeros((self.dim, self.batch_size), dtype=floatX))
        inp_c_history, _ = theano.scan(fn=self.input_gru_step, 
                            sequences=input_var_shuffled,
                            outputs_info=T.zeros_like(inp_dummy))
        
        inp_c_history_shuffled = inp_c_history.dimshuffle(2, 0, 1)
        
        inp_c_list = []
        inp_c_mask_list = []
        for batch_index in range(self.batch_size):
            taken = inp_c_history_shuffled[batch_index].take(self.input_mask_var[batch_index, :self.fact_count_var[batch_index]], axis=0)
            inp_c_list.append(T.concatenate([taken, T.zeros((self.max_fact_count - taken.shape[0], self.dim), floatX)]))
            inp_c_mask_list.append(T.concatenate([T.ones((taken.shape[0],), np.int32), T.zeros((self.max_fact_count - taken.shape[0],), np.int32)]))
        
        self.inp_c = T.stack(inp_c_list).dimshuffle(1, 2, 0)
        inp_c_mask = T.stack(inp_c_mask_list).dimshuffle(1, 0)
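        # Note (editor): the loop above picks out, for every example in the batch,
        # the GRU states at the end-of-sentence positions listed in input_mask_var
        # (the "facts"), zero-pads them to max_fact_count so they can be stacked,
        # and builds inp_c_mask to mark real facts versus padding.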
        
        q_var_shuffled = self.q_var.dimshuffle(1, 2, 0)
        q_dummy = theano.shared(np.zeros((self.dim, self.batch_size), dtype=floatX))
        q_q_history, _ = theano.scan(fn=self.input_gru_step, 
                            sequences=q_var_shuffled,
                            outputs_info=T.zeros_like(q_dummy))
        self.q_q = q_q_history[-1]
        
        
        print "==> creating parameters for memory module"
        self.W_mem_res_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_mem_res_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_mem_res = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_mem_upd_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_mem_upd_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_mem_upd = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_mem_hid_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_mem_hid_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_mem_hid = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        
        self.W_b = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_1 = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, 7 * self.dim + 2)), borrow=True)
        self.W_2 = theano.shared(lasagne.init.Normal(0.1).sample((1, self.dim)), borrow=True)
        self.b_1 = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        self.b_2 = theano.shared(lasagne.init.Constant(0.0).sample((1,)), borrow=True)
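        # Note (editor): W_1's input size of 7 * dim + 2 presumably corresponds to
        # the attention feature vector of the original DMN:
        #     z = [c, m, q, c * q, c * m, |c - q|, |c - m|, c^T W_b q, c^T W_b m]
        # (seven dim-sized blocks plus the two scalar bilinear terms using W_b),
        # fed to a two-layer tanh MLP that produces the scalar gate for each fact.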
        

        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(self.GRU_update(memory[iter - 1], current_episode,
                                          self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                                          self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                                          self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid))                         
        
        last_mem = memory[-1]
        
        
        print "==> building answer module"
        self.W_a = theano.shared(lasagne.init.Normal(0.1).sample((self.vocab_size, self.dim)), borrow=True)
        
        if self.answer_module == 'feedforward':
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))
        
        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim + self.vocab_size)), borrow=True)
            self.W_ans_res_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
            self.b_ans_res = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
            
            self.W_ans_upd_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim + self.vocab_size)), borrow=True)
            self.W_ans_upd_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
            self.b_ans_upd = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
            
            self.W_ans_hid_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim + self.vocab_size)), borrow=True)
            self.W_ans_hid_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
            self.b_ans_hid = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
            def answer_step(prev_a, prev_y):
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                  self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res, 
                                  self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                                  self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid)
                
                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]
            
            # TODO: add conditional ending
            dummy = theano.shared(np.zeros((self.vocab_size, self.batch_size), dtype=floatX))
            results, updates = theano.scan(fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(dummy)], #(last_mem, y)
                n_steps=1)
            self.prediction = results[1][-1]
        
        else:
            raise Exception("invalid answer_module")
        
        self.prediction = self.prediction.dimshuffle(1, 0)
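        # Note (editor): the prediction produced above has shape
        # (vocab_size, batch_size); the dimshuffle turns it into
        # (batch_size, vocab_size) so categorical_crossentropy below can pair each
        # row with the corresponding entry of answer_var.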
                
        self.params = [self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res, 
                  self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
                  self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
                  self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                  self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                  self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid,
                  self.W_b, self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]
        
        if self.answer_module == 'recurrent':
            self.params = self.params + [self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res, 
                              self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                              self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid]
                              
                              
        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction, self.answer_var).mean()
            
        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0
        
        self.loss = self.loss_ce + self.loss_l2
            
        updates = lasagne.updates.adadelta(self.loss, self.params)
        
        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.fact_count_var, self.input_mask_var], 
                                       outputs=[self.prediction, self.loss],
                                       updates=updates,
                                       #mode=NanGuardMode(nan_is_error=True, inf_is_error=True, big_is_error=False)
                                      )
        
        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.fact_count_var, self.input_mask_var],
                                  outputs=[self.prediction, self.loss, self.inp_c, self.q_q, last_mem],
                                  #mode=NanGuardMode(nan_is_error=True, inf_is_error=True, big_is_error=False)
                                  )
        
        
        if self.mode == 'train':
            print "==> computing gradients (for debugging)"
            gradient = T.grad(self.loss, self.params)
            self.get_gradient_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.fact_count_var, self.input_mask_var], outputs=gradient)
    def __init__(self, babi_train_raw, babi_test_raw, word2vec, word_vector_size, 
                dim, mode, input_mask_mode, memory_hops, l2, normalize_attention, **kwargs):

        print "==> not used params in DMN class:", kwargs.keys()
        self.vocab = {}
        self.ivocab = {}
        
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        #self.batch_size = 1
        self.l2 = l2
        self.normalize_attention = normalize_attention

        self.train_input, self.train_q, self.train_answer, self.train_choices, self.train_input_mask = self._process_input(babi_train_raw)
        self.test_input, self.test_q, self.test_answer, self.test_choices, self.test_input_mask = self._process_input(babi_test_raw)
        self.vocab_size = 4 # number of answer choices
        
        self.inp_var = T.matrix('input_var')
        self.q_var = T.matrix('question_var')
        self.ca_var = T.matrix('ca_var')
        self.cb_var = T.matrix('cb_var')
        self.cc_var = T.matrix('cc_var')
        self.cd_var = T.matrix('cd_var')
        self.ans_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')
        
        
        print "==> building input module"
        self.W_inp_res_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.word_vector_size)), borrow=True)
        self.W_inp_res_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_inp_res = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_inp_upd_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.word_vector_size)), borrow=True)
        self.W_inp_upd_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_inp_upd = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_inp_hid_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.word_vector_size)), borrow=True)
        self.W_inp_hid_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_inp_hid = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        inp_c_history, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=self.inp_var,
                    outputs_info=T.zeros_like(self.b_inp_hid))
        
        self.inp_c = inp_c_history.take(self.input_mask_var, axis=0)
        
        self.q_q, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=self.q_var,
                    outputs_info=T.zeros_like(self.b_inp_hid))

        self.q_q = self.q_q[-1]
        
        self.c_vecs = []
        for choice in [self.ca_var, self.cb_var, self.cc_var, self.cd_var]:
            history, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=choice,
                    outputs_info=T.zeros_like(self.b_inp_hid))
            self.c_vecs.append(history[-1])
        
        
        self.c_vecs = T.stack(self.c_vecs).transpose((1, 0)) # (dim, 4)
        self.inp_c = T.stack([self.inp_c] * 4).transpose((1, 2, 0)) # (fact_cnt, dim, 4)
        self.q_q = T.stack([self.q_q] * 4).transpose((1, 0)) # (dim, 4)
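        # Note (editor): the three T.stack(...) calls above replicate the facts and
        # the question four times and stack the four encoded answer choices, so the
        # episodic memory below effectively runs with a batch dimension of 4, one
        # memory per answer choice.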
        
        
        print "==> creating parameters for memory module"
        self.W_mem_res_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_mem_res_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_mem_res = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_mem_upd_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_mem_upd_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_mem_upd = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_mem_hid_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_mem_hid_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_mem_hid = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_b = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_1 = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, 10 * self.dim + 3)), borrow=True)
        self.W_2 = theano.shared(lasagne.init.Normal(0.1).sample((1, self.dim)), borrow=True)
        self.b_1 = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        self.b_2 = theano.shared(lasagne.init.Constant(0.0).sample((1,)), borrow=True)
        

        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()] # (dim, 4)
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(self.GRU_update_batch(memory[iter - 1], current_episode,
                                          self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                                          self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                                          self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid))
                                      
        last_mem = memory[-1].flatten()
        

        print "==> building answer module"
        self.W_a = theano.shared(lasagne.init.Normal(0.1).sample((self.vocab_size, 4 * self.dim)), borrow=True)
        self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))
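        # Note (editor): last_mem flattens the (dim, 4) memory into a single
        # 4 * dim vector and W_a has shape (vocab_size=4, 4 * dim), so the softmax
        # above scores the four answer choices directly against the concatenation
        # of their per-choice memories.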
        
        
        print "==> collecting all parameters"
        self.params = [self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res, 
                  self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
                  self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
                  self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                  self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                  self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid,
                  self.W_b, self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]
        
        
        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction.dimshuffle('x', 0), T.stack([self.ans_var]))[0]
        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0
        
        self.loss = self.loss_ce + self.loss_l2
        
        updates = lasagne.updates.adadelta(self.loss, self.params)
        
        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(inputs=[self.inp_var, self.q_var, self.ans_var,
                                                    self.ca_var, self.cb_var, self.cc_var, self.cd_var,
                                                    self.input_mask_var], 
                                            outputs=[self.prediction, self.loss],
                                            updates=updates)
        
        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[self.inp_var, self.q_var, self.ans_var,
                                                    self.ca_var, self.cb_var, self.cc_var, self.cd_var,
                                                    self.input_mask_var], 
                                        outputs=[self.prediction, self.loss, self.inp_c, self.q_q, last_mem])
        
        
        if self.mode == 'train':
            print "==> computing gradients (for debugging)"
            gradient = T.grad(self.loss, self.params)
            self.get_gradient_fn = theano.function(inputs=[self.inp_var, self.q_var, self.ans_var,
                                                    self.ca_var, self.cb_var, self.cc_var, self.cd_var,
                                                    self.input_mask_var], outputs=gradient)
Exemple #23
0
    def __init__(self, babi_train_raw, babi_test_raw, word2vec,
                 word_vector_size, dim, mode, answer_module, input_mask_mode,
                 memory_hops, l2, normalize_attention):

        self.vocab = {}
        self.ivocab = {}

        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention

        self.train_input, self.train_q, self.train_answer, self.train_input_mask = self._process_input(
            babi_train_raw)
        self.test_input, self.test_q, self.test_answer, self.test_input_mask = self._process_input(
            babi_test_raw)
        self.vocab_size = len(self.vocab)

        self.input_var = T.matrix('input_var')
        self.q_var = T.matrix('question_var')
        self.answer_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')

        print "==> building input module"
        self.W_inp_res_in = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, self.word_vector_size)),
                                          borrow=True)
        self.W_inp_res_hid = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, self.dim)),
                                           borrow=True)
        self.b_inp_res = theano.shared(lasagne.init.Constant(0.0).sample(
            (self.dim, )),
                                       borrow=True)

        self.W_inp_upd_in = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, self.word_vector_size)),
                                          borrow=True)
        self.W_inp_upd_hid = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, self.dim)),
                                           borrow=True)
        self.b_inp_upd = theano.shared(lasagne.init.Constant(0.0).sample(
            (self.dim, )),
                                       borrow=True)

        self.W_inp_hid_in = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, self.word_vector_size)),
                                          borrow=True)
        self.W_inp_hid_hid = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, self.dim)),
                                           borrow=True)
        self.b_inp_hid = theano.shared(lasagne.init.Constant(0.0).sample(
            (self.dim, )),
                                       borrow=True)

        inp_c_history, _ = theano.scan(fn=self.input_gru_step,
                                       sequences=self.input_var,
                                       outputs_info=T.zeros_like(
                                           self.b_inp_hid))

        self.inp_c = inp_c_history.take(self.input_mask_var, axis=0)

        self.q_q, _ = theano.scan(fn=self.input_gru_step,
                                  sequences=self.q_var,
                                  outputs_info=T.zeros_like(self.b_inp_hid))

        self.q_q = self.q_q[-1]

        print "==> creating parameters for memory module"
        self.W_mem_res_in = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, self.dim)),
                                          borrow=True)
        self.W_mem_res_hid = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, self.dim)),
                                           borrow=True)
        self.b_mem_res = theano.shared(lasagne.init.Constant(0.0).sample(
            (self.dim, )),
                                       borrow=True)

        self.W_mem_upd_in = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, self.dim)),
                                          borrow=True)
        self.W_mem_upd_hid = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, self.dim)),
                                           borrow=True)
        self.b_mem_upd = theano.shared(lasagne.init.Constant(0.0).sample(
            (self.dim, )),
                                       borrow=True)

        self.W_mem_hid_in = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, self.dim)),
                                          borrow=True)
        self.W_mem_hid_hid = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, self.dim)),
                                           borrow=True)
        self.b_mem_hid = theano.shared(lasagne.init.Constant(0.0).sample(
            (self.dim, )),
                                       borrow=True)

        self.W_b = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, self.dim)),
                                 borrow=True)
        self.W_1 = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.dim, 7 * self.dim + 2)),
                                 borrow=True)
        self.W_2 = theano.shared(lasagne.init.Normal(0.1).sample(
            (1, self.dim)),
                                 borrow=True)
        self.b_1 = theano.shared(lasagne.init.Constant(0.0).sample(
            (self.dim, )),
                                 borrow=True)
        self.b_2 = theano.shared(lasagne.init.Constant(0.0).sample((1, )),
                                 borrow=True)

        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(
                self.GRU_update(memory[iter - 1], current_episode,
                                self.W_mem_res_in, self.W_mem_res_hid,
                                self.b_mem_res, self.W_mem_upd_in,
                                self.W_mem_upd_hid, self.b_mem_upd,
                                self.W_mem_hid_in, self.W_mem_hid_hid,
                                self.b_mem_hid))

        last_mem = memory[-1]

        print "==> building answer module"
        self.W_a = theano.shared(lasagne.init.Normal(0.1).sample(
            (self.vocab_size, self.dim)),
                                 borrow=True)

        if self.answer_module == 'feedforward':
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))

        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = theano.shared(lasagne.init.Normal(0.1).sample(
                (self.dim, self.dim + self.vocab_size)),
                                              borrow=True)
            self.W_ans_res_hid = theano.shared(lasagne.init.Normal(0.1).sample(
                (self.dim, self.dim)),
                                               borrow=True)
            self.b_ans_res = theano.shared(lasagne.init.Constant(0.0).sample(
                (self.dim, )),
                                           borrow=True)

            self.W_ans_upd_in = theano.shared(lasagne.init.Normal(0.1).sample(
                (self.dim, self.dim + self.vocab_size)),
                                              borrow=True)
            self.W_ans_upd_hid = theano.shared(lasagne.init.Normal(0.1).sample(
                (self.dim, self.dim)),
                                               borrow=True)
            self.b_ans_upd = theano.shared(lasagne.init.Constant(0.0).sample(
                (self.dim, )),
                                           borrow=True)

            self.W_ans_hid_in = theano.shared(lasagne.init.Normal(0.1).sample(
                (self.dim, self.dim + self.vocab_size)),
                                              borrow=True)
            self.W_ans_hid_hid = theano.shared(lasagne.init.Normal(0.1).sample(
                (self.dim, self.dim)),
                                               borrow=True)
            self.b_ans_hid = theano.shared(lasagne.init.Constant(0.0).sample(
                (self.dim, )),
                                           borrow=True)

            def answer_step(prev_a, prev_y):
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                    self.W_ans_res_in, self.W_ans_res_hid,
                                    self.b_ans_res, self.W_ans_upd_in,
                                    self.W_ans_upd_hid, self.b_ans_upd,
                                    self.W_ans_hid_in, self.W_ans_hid_hid,
                                    self.b_ans_hid)

                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]

            # TODO: add conditional ending
            dummy = theano.shared(np.zeros((self.vocab_size, ), dtype=floatX))
            results, updates = theano.scan(
                fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(dummy)],
                n_steps=1)
            self.prediction = results[1][-1]

        else:
            raise Exception("invalid answer_module")

        print "==> collecting all parameters"
        self.params = [
            self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
            self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
            self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
            self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
            self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
            self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, self.W_b,
            self.W_1, self.W_2, self.b_1, self.b_2, self.W_a
        ]

        if self.answer_module == 'recurrent':
            self.params = self.params + [
                self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid
            ]

        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(
            self.prediction.dimshuffle('x', 0), T.stack([self.answer_var]))[0]
        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        updates = lasagne.updates.adadelta(self.loss, self.params)

        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.answer_var,
                    self.input_mask_var
                ],
                outputs=[self.prediction, self.loss],
                updates=updates)

        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[
            self.input_var, self.q_var, self.answer_var, self.input_mask_var
        ],
                                       outputs=[
                                           self.prediction, self.loss,
                                           self.inp_c, self.q_q, last_mem
                                       ])

        if self.mode == 'train':
            print "==> computing gradients (for debugging)"
            gradient = T.grad(self.loss, self.params)
            self.get_gradient_fn = theano.function(inputs=[
                self.input_var, self.q_var, self.answer_var,
                self.input_mask_var
            ],
                                                   outputs=gradient)
Exemple #24
0
    def __init__(self, train_raw, dev_raw, test_raw, word2vec, word_vector_size, 
                dim, mode, input_mask_mode, memory_hops, l2, normalize_attention, dropout, **kwargs):
        print "==> model: GRU, dot similarity, training embedding"
        print "==> not used params in DMN class:", kwargs.keys()
        self.word2vec = word2vec      
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        #self.batch_size = 1
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.dropout = dropout
        
        self.train_input, self.train_q, self.train_answer, self.train_choices, self.train_input_mask = self._process_input(train_raw)
        self.dev_input, self.dev_q, self.dev_answer, self.dev_choices, self.dev_input_mask = self._process_input(dev_raw)
        self.test_input, self.test_q, self.test_answer, self.test_choices, self.test_input_mask = self._process_input(test_raw)
        self.attentions = []
        
        self.inp_var = T.ivector('input_var')
        self.q_var = T.ivector('question_var')
        self.ca_var = T.ivector('ca_var')
        self.cb_var = T.ivector('cb_var')
        self.cc_var = T.ivector('cc_var')
        self.cd_var = T.ivector('cd_var')
        self.ans_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')
        
        print "==> embedding layer"
        self.embed = theano.shared(self.word2vec)
        inp_mat = self.embed[self.inp_var]
        q_mat = self.embed[self.q_var]
        ca_mat = self.embed[self.ca_var]
        cb_mat = self.embed[self.cb_var]
        cc_mat = self.embed[self.cc_var]
        cd_mat = self.embed[self.cd_var]
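        # Note (editor): self.word2vec is presumably an (id, word_vector_size)
        # embedding matrix; indexing the shared variable with an integer id vector
        # yields a (seq_len, word_vector_size) matrix per input. Because self.embed
        # is included in self.params below, these embeddings are fine-tuned during
        # training.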
            
        print "==> building input module"
        self.W_inp_res_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.word_vector_size)), borrow=True)
        self.W_inp_res_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_inp_res = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_inp_upd_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.word_vector_size)), borrow=True)
        self.W_inp_upd_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_inp_upd = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_inp_hid_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.word_vector_size)), borrow=True)
        self.W_inp_hid_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_inp_hid = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        inp_c_history, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=inp_mat,
                    outputs_info=T.zeros_like(self.b_inp_hid))
        
        self.inp_c = inp_c_history.take(self.input_mask_var, axis=0)
        
        self.q_q, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=q_mat,
                    outputs_info=T.zeros_like(self.b_inp_hid))

        self.q_q = self.q_q[-1]        
        
        print "==> creating parameters for memory module"
        self.W_mem_res_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_mem_res_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_mem_res = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_mem_upd_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_mem_upd_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_mem_upd = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_mem_hid_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_mem_hid_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_mem_hid = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_b = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_1 = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, 7 * self.dim)), borrow=True)
        self.W_2 = theano.shared(lasagne.init.Normal(0.1).sample((1, self.dim)), borrow=True)
        self.b_1 = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        self.b_2 = theano.shared(lasagne.init.Constant(0.0).sample((1,)), borrow=True)
        
        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()] # (dim, 1)
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(self.GRU_update(memory[iter - 1], current_episode,
                                          self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                                          self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                                          self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid))
                                      
        last_mem_raw = memory[-1].dimshuffle('x', 0) # (batch_size=1, dim)
        net = layers.InputLayer(shape=(1, self.dim), input_var=last_mem_raw)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net)[0]
        
        print "==> building options module"
        self.c_vecs = []
        for choice in [ca_mat, cb_mat, cc_mat, cd_mat]:
            history, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=choice,
                    outputs_info=T.zeros_like(self.b_inp_hid))
            self.c_vecs.append(history[-1])        
        self.c_vecs = T.stack(self.c_vecs).transpose((1, 0)) # (dim, 4)
        
        print "==> building answer module"
        self.W_a = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.prediction = nn_utils.softmax(T.dot(T.dot(self.W_a, last_mem), self.c_vecs))
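        # Note (editor): W_a here is (dim, dim); T.dot(W_a, last_mem) maps the final
        # memory back into the choice space, and the second dot with c_vecs (dim, 4)
        # yields one score per answer choice, softmaxed over the four options (the
        # "dot similarity" named in the model banner above).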
                
        print "==> collecting all parameters" # embedding matrix is not trained
        self.params = [self.embed,
                  self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res, 
                  self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
                  self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
                  self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                  self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                  self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid,
                  self.W_b, self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]        
        
        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction.dimshuffle('x', 0), T.stack([self.ans_var]))[0]
        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0
        
        self.loss = self.loss_ce + self.loss_l2
        
        updates = lasagne.updates.adadelta(self.loss, self.params)
        
        self.attentions = T.stack(self.attentions)  # stack before compiling: test_fn below also outputs self.attentions

        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(inputs=[self.inp_var, self.q_var, self.ans_var,
                                                    self.ca_var, self.cb_var, self.cc_var, self.cd_var,
                                                    self.input_mask_var],
                                            allow_input_downcast=True,
                                            outputs=[self.prediction, self.loss],
                                            updates=updates)
            
        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[self.inp_var, self.q_var, self.ans_var,
                                               self.ca_var, self.cb_var, self.cc_var, self.cd_var,
                                               self.input_mask_var],
                                       allow_input_downcast = True,
                                       outputs=[self.prediction, self.loss, self.attentions, self.inp_c, self.q_q, last_mem])
        
        
        if self.mode == 'train':
            print "==> computing gradients (for debugging)"
            gradient = T.grad(self.loss, self.params)
            self.get_gradient_fn = theano.function(inputs=[self.inp_var, self.q_var, self.ans_var,
                                                           self.ca_var, self.cb_var, self.cc_var, self.cd_var,
                                                           self.input_mask_var],
                                                   allow_input_downcast = True,
                                                   outputs=gradient)
    def __init__(self,word2vec, word_vector_size, dim,
                mode, answer_module, input_mask_mode, memory_hops, batch_size, l2,
                normalize_attention, batch_norm, dropout,h5file,json_dict_file ,num_answers,img_vector_size,img_seq_len,
                img_h5file_train,img_h5file_test,**kwargs):

        print "==> not used params in DMN class:", kwargs.keys()

        self.vocab = {}
        self.ivocab = {}
        self.lr = 0.001
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.batch_size = batch_size
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.batch_norm = batch_norm
        self.dropout = dropout
        self.h5file=h5py.File(h5file,"r")
        self.img_h5file_train=h5py.File(img_h5file_train,"r")
        self.img_h5file_test=h5py.File(img_h5file_test,"r")
        self.img_seq_len=img_seq_len

        self.img_vector_size=img_vector_size
        with open (json_dict_file) as f:
            self.json_dict=json.load(f)

        #self.train_input, self.train_q, self.train_answer, self.train_fact_count, self.train_input_mask  = self._process_input(babi_train_raw)
        #self.test_input, self.test_q, self.test_answer, self.test_fact_count, self.test_input_mask  = self._process_input(babi_test_raw)
        #self.vocab_size = len(self.vocab)
        self.vocab_size=num_answers

        self.input_var = T.tensor3('input_var') # (batch_size, seq_len, glove_dim)
        self.img_input_var=T.tensor3('img_input_var') # (batch_size * img_seq_len , img_vector_size)
        self.q_var = T.tensor3('question_var') # as self.input_var
        self.answer_var = T.ivector('answer_var') # answer of example in minibatch
        self.fact_count_var = T.ivector('fact_count_var') # number of facts in the example of minibatch
        self.input_mask_var = T.imatrix('input_mask_var') # (batch_size, indices)

        print "==> building input module"
        self.W_inp_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_inp_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_inp_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        input_var_shuffled = self.input_var.dimshuffle(1, 2, 0)
        inp_dummy = theano.shared(np.zeros((self.dim, self.batch_size), dtype=floatX))
        inp_c_history, _ = theano.scan(fn=self.input_gru_step,
                            sequences=input_var_shuffled,
                            outputs_info=T.zeros_like(inp_dummy))

        inp_c_history_shuffled = inp_c_history.dimshuffle(2, 0, 1)

        inp_c_list = []
        inp_c_mask_list = []
        for batch_index in range(self.batch_size):
            taken = inp_c_history_shuffled[batch_index].take(self.input_mask_var[batch_index, :self.fact_count_var[batch_index]], axis=0)
            inp_c_list.append(T.concatenate([taken, T.zeros((self.input_mask_var.shape[1] - taken.shape[0], self.dim), floatX)]))
            inp_c_mask_list.append(T.concatenate([T.ones((taken.shape[0],), np.int32), T.zeros((self.input_mask_var.shape[1] - taken.shape[0],), np.int32)]))

        self.inp_c = T.stack(inp_c_list).dimshuffle(1, 2, 0)
        inp_c_mask = T.stack(inp_c_mask_list).dimshuffle(1, 0)

###################### Adding the Image Input Module

        print "==> building image img_input module"
        ### Don't Really Need the GRU to reduce the sentences into vectors ###
        self.img_input_var=T.reshape(self.img_input_var , ( self.batch_size * self.img_seq_len , self.img_vector_size ))

        img_input_layer=layers.InputLayer( shape=(self.batch_size*self.img_seq_len, self.img_vector_size), input_var=self.img_input_var)

        ## Convert the img_vector_size to self.dim using a MLP ##
        img_input_layer=layers.DenseLayer( img_input_layer , num_units=self.dim )

        img_input_var_dim=layers.get_output(img_input_layer)

        img_input_var_dim=T.reshape(img_input_var_dim ,(self.batch_size , self.img_seq_len , self.dim )  )

        #self.img_inp_c = T.stack(img_input_var_dim).dimshuffle(1, 2, 0)

        self.img_inp_c = img_input_var_dim.dimshuffle(1,2,0)
###################################################
        q_var_shuffled = self.q_var.dimshuffle(1, 2, 0)
        q_dummy = theano.shared(np.zeros((self.dim, self.batch_size), dtype=floatX))
        q_q_history, _ = theano.scan(fn=self.input_gru_step,
                            sequences=q_var_shuffled,
                            outputs_info=T.zeros_like(q_dummy))
        self.q_q = q_q_history[-1]


        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 0))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1,))


        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(self.GRU_update(memory[iter - 1], current_episode,
                                          self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
                                          self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                                          self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid))

        last_mem_raw = memory[-1].dimshuffle((1, 0))

################################# Episodic Memory Module for Image

        print "==> creating parameters for image memory module"
        self.W_img_mem_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_img_mem_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_img_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_img_mem_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_img_mem_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_img_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_img_mem_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_img_mem_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_img_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_img_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_img_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 0))
        self.W_img_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_img_1 = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        self.b_img_2 = nn_utils.constant_param(value=0.0, shape=(1,))


        print "==> building episodic img_memory module (fixed number of steps: %d)" % self.memory_hops
        img_memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_img_episode(img_memory[iter - 1])
            img_memory.append(self.GRU_update(img_memory[iter - 1], current_episode,
                                          self.W_img_mem_res_in, self.W_img_mem_res_hid, self.b_img_mem_res,
                                          self.W_img_mem_upd_in, self.W_img_mem_upd_hid, self.b_img_mem_upd,
                                          self.W_img_mem_hid_in, self.W_img_mem_hid_hid, self.b_img_mem_hid))

        last_img_mem_raw = img_memory[-1].dimshuffle((1, 0))




#######################################################################

        ### Concatenate the two memory module representations; each is assumed to be (self.batch_size, self.dim) ###

        combined_mem_raw = T.concatenate([last_mem_raw, last_img_mem_raw], axis=1)

        #net = layers.InputLayer(shape=(self.batch_size, self.dim), input_var=last_mem_raw)

        net = layers.InputLayer(shape=(self.batch_size, self.dim+self.dim), input_var=combined_mem_raw)
        if self.batch_norm:
            net = layers.BatchNormLayer(incoming=net)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net).dimshuffle((1, 0))


        print "==> building answer module"
        #self.W_a = nn_utils.normal_param(std=0.1, shape=(self.vocab_size, self.dim))
        self.W_a = nn_utils.normal_param(std=0.1, shape=(self.vocab_size, self.dim+self.dim))
        if self.answer_module == 'feedforward':
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))

        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = nn_utils.normal_param(std=0.1, shape=(2*self.dim, self.dim + self.vocab_size))
            self.W_ans_res_hid = nn_utils.normal_param(std=0.1, shape=(2*self.dim, 2*self.dim))
            self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(2*self.dim,))

            self.W_ans_upd_in = nn_utils.normal_param(std=0.1, shape=(2*self.dim, self.dim + self.vocab_size))
            self.W_ans_upd_hid = nn_utils.normal_param(std=0.1, shape=(2*self.dim,2*self.dim))
            self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(2*self.dim,))

            self.W_ans_hid_in = nn_utils.normal_param(std=0.1, shape=(2*self.dim, self.dim + self.vocab_size))
            self.W_ans_hid_hid = nn_utils.normal_param(std=0.1, shape=(2*self.dim, 2*self.dim))
            self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(2*self.dim,))

            def answer_step(prev_a, prev_y):
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                  self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                                  self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                                  self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid)

                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]

            # TODO: add conditional ending
            dummy = theano.shared(np.zeros((self.vocab_size, self.batch_size), dtype=floatX))
            results, updates = theano.scan(fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(dummy)], #(last_mem, y)
                n_steps=1)
            self.prediction = results[1][-1]

        else:
            raise Exception("invalid answer_module")

        self.prediction = self.prediction.dimshuffle(1, 0)

        self.params = [self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
                  self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
                  self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
                  self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
                  self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                  self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, #self.W_b
                  self.W_1, self.W_2, self.b_1, self.b_2, self.W_a,  ## parameters of the image episodic memory module
                  self.W_img_mem_res_in, self.W_img_mem_res_hid, self.b_img_mem_res,
                  self.W_img_mem_upd_in, self.W_img_mem_upd_hid, self.b_img_mem_upd,
                  self.W_img_mem_hid_in, self.W_img_mem_hid_hid, self.b_img_mem_hid, #self.W_img_b
                  self.W_img_1, self.W_img_2, self.b_img_1, self.b_img_2]

        ## parameters of the dense layer that projects image features to self.dim (the image input module)
        dim_transform_mlp_params = layers.get_all_params(img_input_layer)

        self.params = self.params + dim_transform_mlp_params

        if self.answer_module == 'recurrent':
            self.params = self.params + [self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                              self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                              self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid]


        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction, self.answer_var).mean()

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        self.learning_rate = T.scalar(name="learning_rate")
        updates = lasagne.updates.adam(self.loss, self.params, learning_rate=self.learning_rate)
        #updates = lasagne.updates.adadelta(self.loss, self.params)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.001)

        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var,
                                                    self.fact_count_var, self.input_mask_var, self.img_input_var, self.learning_rate],
                                            outputs=[self.prediction, self.loss],
                                            updates=updates)

        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var,
                                               self.fact_count_var, self.input_mask_var, self.img_input_var, self.learning_rate],
                                       on_unused_input='ignore',
                                       outputs=[self.prediction, self.loss])
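
The symbolic variables above fix the argument shapes that train_fn / test_fn expect. A minimal sketch of matching dummy inputs follows; the sizes are illustrative assumptions and `dmn` stands for a hypothetical instance of the class above, so the actual call is left commented out.

import numpy as np

# Illustrative sizes only; the real values come from the preprocessing pipeline.
batch_size, seq_len, q_len, glove_dim = 32, 40, 10, 50
img_seq_len, img_vector_size = 14, 4096

inp   = np.zeros((batch_size, seq_len, glove_dim), dtype='float32')            # input_var
q     = np.zeros((batch_size, q_len, glove_dim), dtype='float32')              # q_var
ans   = np.zeros((batch_size,), dtype='int32')                                 # answer_var
facts = np.ones((batch_size,), dtype='int32')                                  # fact_count_var
mask  = np.zeros((batch_size, seq_len), dtype='int32')                         # input_mask_var
img   = np.zeros((batch_size, img_seq_len, img_vector_size), dtype='float32')  # img_input_var
lr    = 0.001                                                                   # learning_rate

# Hypothetical call; `dmn` would be an instance of the class above.
# prediction, loss = dmn.train_fn(inp, q, ans, facts, mask, img, lr)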
Example #26
0
    def __init__(self, babi_train_raw, babi_test_raw, word2vec,
                 word_vector_size, dim, mode, answer_module, input_mask_mode,
                 memory_hops, l2, normalize_attention, batch_norm, dropout,
                 **kwargs):

        print "==> not used params in DMN class:", kwargs.keys()
        self.vocab = {}
        self.ivocab = {}

        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.batch_norm = batch_norm
        self.dropout = dropout

        self.train_input, self.train_q, self.train_answer, self.train_input_mask = self._process_input(
            babi_train_raw)
        self.test_input, self.test_q, self.test_answer, self.test_input_mask = self._process_input(
            babi_test_raw)
        self.vocab_size = len(self.vocab)

        self.input_var = T.matrix('input_var')
        self.q_var = T.matrix('question_var')
        self.answer_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')

        self.attentions = []

        print "==> building input module"
        self.W_inp_res_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_upd_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_hid_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        inp_c_history, _ = theano.scan(fn=self.input_gru_step,
                                       sequences=self.input_var,
                                       outputs_info=T.zeros_like(
                                           self.b_inp_hid))

        self.inp_c = inp_c_history.take(self.input_mask_var, axis=0)

        self.q_q, _ = theano.scan(fn=self.input_gru_step,
                                  sequences=self.q_var,
                                  outputs_info=T.zeros_like(self.b_inp_hid))

        self.q_q = self.q_q[-1]

        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        #self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 0))
        self.W_111 = nn_utils.normal_param(std=0.1,
                                           shape=(self.dim, self.dim,
                                                  self.dim))
        self.W_112 = nn_utils.normal_param(std=0.1,
                                           shape=(self.dim, self.dim,
                                                  self.dim))
        self.W_113 = nn_utils.normal_param(std=0.1,
                                           shape=(self.dim, self.dim,
                                                  self.dim))
        self.W_12 = nn_utils.normal_param(std=0.1,
                                          shape=(self.dim, 3 * self.dim))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim, ))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1, ))

        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(
                self.GRU_update(memory[iter - 1], current_episode,
                                self.W_mem_res_in, self.W_mem_res_hid,
                                self.b_mem_res, self.W_mem_upd_in,
                                self.W_mem_upd_hid, self.b_mem_upd,
                                self.W_mem_hid_in, self.W_mem_hid_hid,
                                self.b_mem_hid))

        last_mem_raw = memory[-1].dimshuffle(('x', 0))

        net = layers.InputLayer(shape=(1, self.dim), input_var=last_mem_raw)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net)[0]

        print "==> building answer module"
        self.W_a = nn_utils.normal_param(std=0.1,
                                         shape=(self.vocab_size, self.dim))

        if self.answer_module == 'feedforward':
            self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))

        elif self.answer_module == 'recurrent':
            self.W_ans_res_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_res_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_res = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            self.W_ans_upd_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_upd_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_upd = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            self.W_ans_hid_in = nn_utils.normal_param(
                std=0.1, shape=(self.dim, self.dim + self.vocab_size))
            self.W_ans_hid_hid = nn_utils.normal_param(std=0.1,
                                                       shape=(self.dim,
                                                              self.dim))
            self.b_ans_hid = nn_utils.constant_param(value=0.0,
                                                     shape=(self.dim, ))

            def answer_step(prev_a, prev_y):
                a = self.GRU_update(prev_a, T.concatenate([prev_y, self.q_q]),
                                    self.W_ans_res_in, self.W_ans_res_hid,
                                    self.b_ans_res, self.W_ans_upd_in,
                                    self.W_ans_upd_hid, self.b_ans_upd,
                                    self.W_ans_hid_in, self.W_ans_hid_hid,
                                    self.b_ans_hid)

                y = nn_utils.softmax(T.dot(self.W_a, a))
                return [a, y]

            # TODO: add conditional ending
            dummy = theano.shared(np.zeros((self.vocab_size, ), dtype=floatX))
            results, updates = theano.scan(
                fn=answer_step,
                outputs_info=[last_mem, T.zeros_like(dummy)],
                n_steps=1)
            self.prediction = results[1][-1]

        else:
            raise Exception("invalid answer_module")

        print "==> collecting all parameters"
        self.params = [
            self.W_inp_res_in,
            self.W_inp_res_hid,
            self.b_inp_res,
            self.W_inp_upd_in,
            self.W_inp_upd_hid,
            self.b_inp_upd,
            self.W_inp_hid_in,
            self.W_inp_hid_hid,
            self.b_inp_hid,
            self.W_mem_res_in,
            self.W_mem_res_hid,
            self.b_mem_res,
            self.W_mem_upd_in,
            self.W_mem_upd_hid,
            self.b_mem_upd,
            self.W_mem_hid_in,
            self.W_mem_hid_hid,
            self.b_mem_hid,  #self.W_b
            #self.W_1,
            self.W_111,
            self.W_112,
            self.W_113,
            self.W_12,
            self.W_2,
            self.b_1,
            self.b_2,
            self.W_a
        ]

        if self.answer_module == 'recurrent':
            self.params = self.params + [
                self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
                self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid
            ]

        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(
            self.prediction.dimshuffle('x', 0), T.stack([self.answer_var]))[0]

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        updates = lasagne.updates.adadelta(self.loss, self.params)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.0003)

        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.answer_var,
                    self.input_mask_var
                ],
                outputs=[self.prediction, self.loss],
                updates=updates)

        self.attentions = T.stack(self.attentions)
        print "==> compiling test_fn"
        self.test_fn = theano.function(
            inputs=[
                self.input_var, self.q_var, self.answer_var,
                self.input_mask_var
            ],
            outputs=[self.prediction, self.loss, self.attentions])
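
Every constructor in these listings builds the episodic memory the same way: the memory starts as the encoded question and, for a fixed number of hops, a newly attended episode is folded in with a GRU update. Below is a minimal NumPy sketch of that loop; the plain GRU cell and the dot-product attention are simplified stand-ins for GRU_update and new_episode, whose exact gating is defined elsewhere in these classes.

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_update(h, x, W):
    # One standard GRU cell; a stand-in for GRU_update above (the exact variant may differ).
    r = sigmoid(np.dot(W['res_in'], x) + np.dot(W['res_hid'], h) + W['b_res'])
    z = sigmoid(np.dot(W['upd_in'], x) + np.dot(W['upd_hid'], h) + W['b_upd'])
    h_tilde = np.tanh(np.dot(W['hid_in'], x) + r * np.dot(W['hid_hid'], h) + W['b_hid'])
    return z * h + (1.0 - z) * h_tilde

def new_episode(facts, mem, q):
    # Simplified dot-product attention standing in for the gated attention above.
    scores = softmax(np.dot(facts, mem + q))   # one attention score per fact
    return np.dot(scores, facts)               # attention-weighted sum of the facts

dim, n_facts, hops = 8, 5, 3
rng = np.random.RandomState(0)
facts = rng.randn(n_facts, dim)
q_q = rng.randn(dim)
W = {k: 0.1 * rng.randn(dim, dim)
     for k in ['res_in', 'res_hid', 'upd_in', 'upd_hid', 'hid_in', 'hid_hid']}
W.update({k: np.zeros(dim) for k in ['b_res', 'b_upd', 'b_hid']})

memory = [q_q.copy()]
for hop in range(hops):
    episode = new_episode(facts, memory[-1], q_q)
    memory.append(gru_update(memory[-1], episode, W))
print(memory[-1].shape)   # (dim,) -- the final memory fed to the answer module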
Example #27
0
    def __init__(self, train_raw, dev_raw, test_raw, word2vec, word_vector_size, 
                dim, mode, input_mask_mode, memory_hops, l2, lr, normalize_attention, dropout, **kwargs):
        print "==> model: GRU, pending options, training embedding"
        print "==> not used params in DMN class:", kwargs.keys()
        self.word2vec = word2vec      
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        #self.batch_size = 1
        self.l2 = l2
        self.lr = lr
        self.normalize_attention = normalize_attention
        self.dropout = dropout
        
        self.train_input, self.train_q, self.train_answer, self.train_choices, self.train_input_mask = self._process_input(train_raw)
        self.dev_input, self.dev_q, self.dev_answer, self.dev_choices, self.dev_input_mask = self._process_input(dev_raw)
        self.test_input, self.test_q, self.test_answer, self.test_choices, self.test_input_mask = self._process_input(test_raw)
        self.vocab_size = 4 # number of answer choices
        
        self.inp_var = T.ivector('input_var')
        self.q_var = T.ivector('question_var')
        self.ca_var = T.ivector('ca_var')
        self.cb_var = T.ivector('cb_var')
        self.cc_var = T.ivector('cc_var')
        self.cd_var = T.ivector('cd_var')
        self.ans_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')
        
        print "==> embedding layer"
        self.embed = theano.shared(self.word2vec)
        inp_mat = self.embed[self.inp_var]
        q_mat = self.embed[self.q_var]
        ca_mat = self.embed[self.ca_var]
        cb_mat = self.embed[self.cb_var]
        cc_mat = self.embed[self.cc_var]
        cd_mat = self.embed[self.cd_var]
        
        print "==> building input module"
        self.W_inp_res_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.word_vector_size)), borrow=True)
        self.W_inp_res_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_inp_res = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_inp_upd_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.word_vector_size)), borrow=True)
        self.W_inp_upd_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_inp_upd = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_inp_hid_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.word_vector_size)), borrow=True)
        self.W_inp_hid_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_inp_hid = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        inp_c_history, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=inp_mat,
                    outputs_info=T.zeros_like(self.b_inp_hid))
        
        self.inp_c = inp_c_history.take(self.input_mask_var, axis=0)
        
        self.q_q, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=q_mat,
                    outputs_info=T.zeros_like(self.b_inp_hid))

        self.q_q = self.q_q[-1]
        
        self.c_vecs = []
        for choice in [ca_mat, cb_mat, cc_mat, cd_mat]:
            history, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=choice,
                    outputs_info=T.zeros_like(self.b_inp_hid))
            self.c_vecs.append(history[-1])
        
        
        self.c_vecs = T.stack(self.c_vecs).transpose((1, 0)) # (dim, 4)
        self.inp_c = T.stack([self.inp_c] * 4).transpose((1, 2, 0)) # (fact_cnt, dim, 4)
        self.q_q = T.stack([self.q_q] * 4).transpose((1, 0)) # (dim, 4)
        
        
        print "==> creating parameters for memory module"
        self.W_mem_res_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_mem_res_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_mem_res = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_mem_upd_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_mem_upd_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_mem_upd = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_mem_hid_in = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_mem_hid_hid = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.b_mem_hid = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        
        self.W_b = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, self.dim)), borrow=True)
        self.W_1 = theano.shared(lasagne.init.Normal(0.1).sample((self.dim, 10 * self.dim + 3)), borrow=True)
        self.W_2 = theano.shared(lasagne.init.Normal(0.1).sample((1, self.dim)), borrow=True)
        self.b_1 = theano.shared(lasagne.init.Constant(0.0).sample((self.dim,)), borrow=True)
        self.b_2 = theano.shared(lasagne.init.Constant(0.0).sample((1,)), borrow=True)
        

        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()] # (dim, 4)
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(self.GRU_update_batch(memory[iter - 1], current_episode,
                                          self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                                          self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                                          self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid))
                                      
        last_mem_raw = memory[-1].flatten().dimshuffle('x', 0) # (dim*4)
        net = layers.InputLayer(shape=(1, 4 * self.dim), input_var=last_mem_raw)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net)[0]

        print "==> building answer module"
        self.W_a = theano.shared(lasagne.init.Normal(0.1).sample((self.vocab_size, 4 * self.dim)), borrow=True)
        self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))
        
        
        print "==> collecting all parameters" # embedding matrix is not trained
        self.params = [self.embed,
                  self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res, 
                  self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
                  self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
                  self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                  self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                  self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid,
                  self.W_b, self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]
        
        
        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction.dimshuffle('x', 0), T.stack([self.ans_var]))[0]
        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0
        
        self.loss = self.loss_ce + self.loss_l2
        
        updates = lasagne.updates.adadelta(self.loss, self.params, learning_rate=self.lr)
        
        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(inputs=[self.inp_var, self.q_var, self.ans_var,
                                                    self.ca_var, self.cb_var, self.cc_var, self.cd_var,
                                                    self.input_mask_var],
                                            allow_input_downcast = True,
                                            outputs=[self.prediction, self.loss],
                                            updates=updates)
        
        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[self.inp_var, self.q_var, self.ans_var,
                                               self.ca_var, self.cb_var, self.cc_var, self.cd_var,
                                               self.input_mask_var],
                                       allow_input_downcast = True,
                                       outputs=[self.prediction, self.loss, self.inp_c, self.q_q, last_mem])
        
        
        if self.mode == 'train':
            print "==> computing gradients (for debugging)"
            gradient = T.grad(self.loss, self.params)
            self.get_gradient_fn = theano.function(inputs=[self.inp_var, self.q_var, self.ans_var,
                                                           self.ca_var, self.cb_var, self.cc_var, self.cd_var,
                                                           self.input_mask_var],
                                                   allow_input_downcast = True,
                                                   outputs=gradient)
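
This multiple-choice variant reuses the single-example machinery by replicating the encoded facts and question across the four answer choices, so every memory hop runs over a (dim, 4) pseudo-batch. A small NumPy sketch of that shape bookkeeping, with illustrative sizes:

import numpy as np

dim, fact_cnt = 8, 5
rng = np.random.RandomState(0)
c_vecs = rng.randn(4, dim)          # one encoded vector per answer choice
inp_c  = rng.randn(fact_cnt, dim)   # encoded facts
q_q    = rng.randn(dim)             # encoded question

c_vecs = c_vecs.transpose((1, 0))                      # (dim, 4)
inp_c  = np.stack([inp_c] * 4).transpose((1, 2, 0))    # (fact_cnt, dim, 4)
q_q    = np.stack([q_q] * 4).transpose((1, 0))         # (dim, 4)
print(c_vecs.shape, inp_c.shape, q_q.shape)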
Example #28
0
    def __init__(self, babi_train_raw, babi_test_raw, word2vec, word_vector_size, 
                memory_hops, dim, mode, input_mask_mode, l2, batch_norm, dropout, **kwargs):

        print "==> not used params in DMN class:", kwargs.keys()
        self.vocab = {}
        self.ivocab = {}
        
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.input_mask_mode = input_mask_mode
        self.l2 = l2
        self.batch_norm = batch_norm
        self.dropout = dropout
        self.memory_hops = memory_hops
        
        self.train_input, self.train_q, self.train_answer, self.train_input_mask, self.train_gates = self._process_input(babi_train_raw)
        self.test_input, self.test_q, self.test_answer, self.test_input_mask, self.test_gates = self._process_input(babi_test_raw)
        self.vocab_size = len(self.vocab)
        
        print "Train size: ", len(self.train_input)
        print "Test size: ", len(self.test_input)
        print "Vocab size: ", self.vocab_size
        
        self.input_var = T.matrix('input_var')
        self.q_var = T.matrix('question_var')
        self.answer_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')
        self.gates_var = T.ivector('gates_var') # attention gate (including end_reading)
        
        self.attentions = []
            
        print "==> building input module"
        self.W_inp_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_inp_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_inp_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        inp_c_history, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=self.input_var,
                    outputs_info=T.zeros_like(self.b_inp_hid))
        self.end_reading = nn_utils.constant_param(value=0.0,shape=(1,self.dim))           
        inp_c_tag = T.concatenate([inp_c_history,self.end_reading],axis=0)
        
        self.inp_c = inp_c_tag.take(self.input_mask_var, axis=0) #(facts_len,dim)
        
        self.q_q, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=self.q_var,
                    outputs_info=T.zeros_like(self.b_inp_hid))
        self.q_q = self.q_q[-1] #(dim,)
        
        
        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_mem_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_mem_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 2))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1,))


        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(0, self.memory_hops):
            current_episode, g = self.new_episode(memory[iter])
            self.attentions.append(g)
            memory.append(self.GRU_update(memory[iter], current_episode,
                                          self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                                          self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                                          self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid))
        
        last_mem_raw = memory[-1].dimshuffle(('x', 0))
       
        net = layers.InputLayer(shape=(1, self.dim), input_var=last_mem_raw)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net)[0]
               
        self.attentions = T.stack(self.attentions) #(memory_hops, fact_cnt)
        
        print "==> building answer module"
        self.W_a = nn_utils.normal_param(std=0.1, shape=(self.vocab_size, self.dim))
        self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))        
        
        print "==> collecting all parameters"
        self.params = [self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res, 
                  self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
                  self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
                  self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                  self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                  self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, self.W_b,
                  self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]        
        
        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction.dimshuffle('x', 0), 
                                                       T.stack([self.answer_var]))[0]

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0
        
        self.loss_gate = T.nnet.categorical_crossentropy(self.attentions, self.gates_var).mean()
        
        self.loss = self.loss_ce + self.loss_l2 + self.loss_gate
        
        updates = lasagne.updates.adam(self.loss, self.params)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.0003)
        
        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var, self.gates_var], 
                                            allow_input_downcast = True,
                                            outputs=[self.prediction, self.loss, self.attentions],
                                            updates=updates)
        
        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var, self.gates_var],
                                       allow_input_downcast = True,
                                       outputs=[self.prediction, self.loss, self.attentions])

        if self.mode == 'train':
            print "==> computing gradients (for debugging)"
            gradient = T.grad(self.loss, self.params)
            self.get_gradient_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var, self.gates_var],
                                                           allow_input_downcast = True, outputs=gradient)        
    def __init__(self, train_raw, dev_raw, test_raw, word2vec, word_vector_size, 
                dim, mode, input_mask_mode, memory_hops, l2, normalize_attention, dropout, **kwargs):
        print "generate one-word answer for mctest"
        print "==> not used params in DMN class:", kwargs.keys()
        self.word2vec = word2vec      
        self.word_vector_size = word_vector_size
        self.vocab_size = len(word2vec)
        
        self.dim = dim # hidden state size
        self.mode = mode
        self.input_mask_mode = input_mask_mode
        self.memory_hops = memory_hops
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.dropout = dropout
        
        self.train_input, self.train_q, self.train_answer, self.train_input_mask = self._process_input(train_raw)
        self.dev_input, self.dev_q, self.dev_answer, self.dev_input_mask = self._process_input(dev_raw)
        self.test_input, self.test_q, self.test_answer, self.test_input_mask = self._process_input(test_raw)
        
        self.input_var = T.matrix('input_var')
        self.q_var = T.matrix('question_var')
        self.answer_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')
        self.attentions = []
            
        print "==> building input module"
        self.W_inp_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_inp_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_inp_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        inp_c_history, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=self.input_var,
                    outputs_info=T.zeros_like(self.b_inp_hid))
        
        self.inp_c = inp_c_history.take(self.input_mask_var, axis=0)
        
        self.q_q, _ = theano.scan(fn=self.input_gru_step, 
                    sequences=self.q_var,
                    outputs_info=T.zeros_like(self.b_inp_hid))

        self.q_q = self.q_q[-1]
        
        
        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_mem_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_mem_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 2))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1,))


        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            current_episode = self.new_episode(memory[iter - 1])
            memory.append(self.GRU_update(memory[iter - 1], current_episode,
                                          self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                                          self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                                          self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid))
        
        last_mem_raw = memory[-1].dimshuffle(('x', 0))
        
        net = layers.InputLayer(shape=(1, self.dim), input_var=last_mem_raw)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net)[0]
        self.attentions = T.stack(self.attentions)
        
        print "==> building answer module"
        self.W_a = nn_utils.normal_param(std=0.1, shape=(self.vocab_size, self.dim))
        
        self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))
        
        print "==> collecting all parameters"
        self.params = [self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res, 
                  self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
                  self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
                  self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                  self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                  self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, self.W_b,
                  self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]
        
        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction.dimshuffle('x', 0), T.stack([self.answer_var]))[0]

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0
        
        self.loss = self.loss_ce + self.loss_l2
        
        updates = lasagne.updates.adam(self.loss, self.params)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.0003)
        
        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var], 
                                            allow_input_downcast = True,
                                            outputs=[self.prediction, self.loss],
                                            updates=updates)
        
        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var],
                                       allow_input_downcast = True,
                                       outputs=[self.prediction, self.loss, self.attentions])
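
The loss in these single-example models scores one softmax distribution against one integer label; the dimshuffle('x', 0) / T.stack([...]) pair only adds the batch axis that categorical_crossentropy expects. A NumPy check of the same quantity, assuming prediction is already a normalized softmax over the vocabulary:

import numpy as np

vocab_size = 6
rng = np.random.RandomState(0)
logits = rng.randn(vocab_size)
prediction = np.exp(logits) / np.exp(logits).sum()   # softmax over the vocabulary
answer = 2                                           # integer class label

loss_ce = -np.log(prediction[answer])                # categorical cross-entropy for one example
print(loss_ce)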
Example #30
0
    def __init__(self, babi_train_raw, babi_test_raw, word2vec,
                 word_vector_size, memory_hops, dim, mode, input_mask_mode, l2,
                 batch_norm, dropout, **kwargs):

        print "==> not used params in DMN class:", kwargs.keys()
        self.vocab = {}
        self.ivocab = {}

        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.input_mask_mode = input_mask_mode
        self.l2 = l2
        self.batch_norm = batch_norm
        self.dropout = dropout
        self.memory_hops = memory_hops

        self.train_input, self.train_q, self.train_answer, self.train_input_mask, self.train_gates = self._process_input(
            babi_train_raw)
        self.test_input, self.test_q, self.test_answer, self.test_input_mask, self.test_gates = self._process_input(
            babi_test_raw)
        self.vocab_size = len(self.vocab)

        print "Train size: ", len(self.train_input)
        print "Test size: ", len(self.test_input)
        print "Vocab size: ", self.vocab_size

        self.input_var = T.matrix('input_var')
        self.q_var = T.matrix('question_var')
        self.answer_var = T.iscalar('answer_var')
        self.input_mask_var = T.ivector('input_mask_var')
        self.gates_var = T.ivector(
            'gates_var')  # attention gate (including end_reading)

        self.attentions = []

        print "==> building input module"
        self.W_inp_res_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_upd_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_inp_hid_in = nn_utils.normal_param(
            std=0.1, shape=(self.dim, self.word_vector_size))
        self.W_inp_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_inp_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        inp_c_history, _ = theano.scan(fn=self.input_gru_step,
                                       sequences=self.input_var,
                                       outputs_info=T.zeros_like(
                                           self.b_inp_hid))
        self.end_reading = nn_utils.constant_param(value=0.0,
                                                   shape=(1, self.dim))
        inp_c_tag = T.concatenate([inp_c_history, self.end_reading], axis=0)

        self.inp_c = inp_c_tag.take(self.input_mask_var,
                                    axis=0)  #(facts_len,dim)

        self.q_q, _ = theano.scan(fn=self.input_gru_step,
                                  sequences=self.q_var,
                                  outputs_info=T.zeros_like(self.b_inp_hid))
        self.q_q = self.q_q[-1]  #(dim,)

        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1,
                                         shape=(self.dim, 7 * self.dim + 2))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim, ))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1, ))

        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(0, self.memory_hops):
            current_episode, g = self.new_episode(memory[iter])
            self.attentions.append(g)
            memory.append(
                self.GRU_update(memory[iter], current_episode,
                                self.W_mem_res_in, self.W_mem_res_hid,
                                self.b_mem_res, self.W_mem_upd_in,
                                self.W_mem_upd_hid, self.b_mem_upd,
                                self.W_mem_hid_in, self.W_mem_hid_hid,
                                self.b_mem_hid))

        last_mem_raw = memory[-1].dimshuffle(('x', 0))

        net = layers.InputLayer(shape=(1, self.dim), input_var=last_mem_raw)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net)[0]

        self.attentions = T.stack(self.attentions)  #(memory_hops, fact_cnt)

        print "==> building answer module"
        self.W_a = nn_utils.normal_param(std=0.1,
                                         shape=(self.vocab_size, self.dim))
        self.prediction = nn_utils.softmax(T.dot(self.W_a, last_mem))

        print "==> collecting all parameters"
        self.params = [
            self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
            self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
            self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
            self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
            self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
            self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, self.W_b,
            self.W_1, self.W_2, self.b_1, self.b_2, self.W_a
        ]

        print "==> building loss layer and computing updates"
        self.loss_ce = T.nnet.categorical_crossentropy(
            self.prediction.dimshuffle('x', 0), T.stack([self.answer_var]))[0]

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss_gate = T.nnet.categorical_crossentropy(
            self.attentions, self.gates_var).mean()

        self.loss = self.loss_ce + self.loss_l2 + self.loss_gate

        updates = lasagne.updates.adam(self.loss, self.params)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.0003)

        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.answer_var,
                    self.input_mask_var, self.gates_var
                ],
                allow_input_downcast=True,
                outputs=[self.prediction, self.loss, self.attentions],
                updates=updates)

        print "==> compiling test_fn"
        self.test_fn = theano.function(
            inputs=[
                self.input_var, self.q_var, self.answer_var,
                self.input_mask_var, self.gates_var
            ],
            allow_input_downcast=True,
            outputs=[self.prediction, self.loss, self.attentions])

        if self.mode == 'train':
            print "==> computing gradients (for debugging)"
            gradient = T.grad(self.loss, self.params)
            self.get_gradient_fn = theano.function(inputs=[
                self.input_var, self.q_var, self.answer_var,
                self.input_mask_var, self.gates_var
            ],
                                                   allow_input_downcast=True,
                                                   outputs=gradient)
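
get_gradient_fn returns one gradient array per entry of self.params, in the same order, which makes it easy to scan for exploding or NaN gradients during debugging. A small helper sketch for that inspection; the call to get_gradient_fn itself is hypothetical and left commented out.

import numpy as np

def summarize_grads(grads, names=None):
    # Print shape, L2 norm, and NaN count for each gradient array.
    for i, g in enumerate(grads):
        g = np.asarray(g)
        label = names[i] if names else 'param_%d' % i
        print('%s shape=%s norm=%.4f nans=%d'
              % (label, g.shape, np.sqrt((g ** 2).sum()), np.isnan(g).sum()))

# Hypothetical usage with an instance `dmn` of the class above:
# grads = dmn.get_gradient_fn(inp, q, ans, mask, gates)
# summarize_grads(grads)

summarize_grads([np.ones((3, 3)), np.zeros(5)])   # tiny self-contained demo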