Exemple #1
0
    def __init__(self, data_dir, word2vec, word_vector_size, dim, cnn_dim,
                 story_len, mode, answer_module, memory_hops, batch_size, l2,
                 normalize_attention, batch_norm, dropout, **kwargs):
        # Build the complete model: store the hyper-parameters, load the
        # vocabulary and the train/val/test stories, construct the symbolic
        # Theano graph (input, question, episodic-memory and answer modules)
        # and compile the train / test / prediction functions.
        #
        # Arguments, as they are used inside this constructor:
        #   data_dir     -- handed to _load_vocab and _process_input_sind_lmdb.
        #   dim          -- size of every embedding / hidden state (D below).
        #   cnn_dim      -- size of one input image feature vector (C below).
        #   story_len    -- number of images per story (S below).
        #   mode         -- 'train' enables the dropout layers and the
        #                   compilation of train_fn.
        #   memory_hops  -- number of episodic-memory update iterations.
        #   batch_size   -- number of stories per minibatch (B below); baked
        #                   into the initial GRU states and the Lasagne
        #                   InputLayer shapes.
        #   l2           -- L2 regularisation weight, applied only if > 0.
        #   batch_norm   -- if truthy, insert Lasagne BatchNormLayer's.
        #   dropout      -- dropout probability, applied only if > 0 and
        #                   mode == 'train'.
        #   word2vec, word_vector_size, answer_module, normalize_attention
        #                -- only stored on self here; not otherwise read in
        #                   this constructor.
        #   **kwargs     -- extra options are ignored; their names are printed.
        #
        # Shape shorthand used in the comments below:
        #   B = batch_size, S = story_len, C = cnn_dim, D = dim,
        #   V = vocab_size + 1, T = number of answer input time steps.
        # Shapes that depend on helpers defined outside this method
        # (input_gru_step_forward, new_episode, GRU_update, answer_gru_step)
        # are marked as assumptions.

        print "==> not used params in DMN class:", kwargs.keys()

        self.data_dir = data_dir

        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.cnn_dim = cnn_dim
        self.story_len = story_len
        self.mode = mode
        self.answer_module = answer_module
        self.memory_hops = memory_hops
        self.batch_size = batch_size
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.batch_norm = batch_norm
        self.dropout = dropout

        # vocab / ivocab come from a helper defined elsewhere in the class
        # (presumably word->index and index->word maps -- TODO confirm).
        self.vocab, self.ivocab = self._load_vocab(self.data_dir)

        # Placeholders; both attributes are overwritten a few lines below.
        self.train_story = None
        self.test_story = None
        # Each split yields a story dictionary plus an LMDB environment
        # handle (naming suggests fc image features -- TODO confirm).
        self.train_dict_story, self.train_lmdb_env_fc = self._process_input_sind_lmdb(
            self.data_dir, 'train')
        self.val_dict_story, self.val_lmdb_env_fc = self._process_input_sind_lmdb(
            self.data_dir, 'val')
        self.test_dict_story, self.test_lmdb_env_fc = self._process_input_sind_lmdb(
            self.data_dir, 'test')

        # Story ids of every split (the keys of the story dictionaries).
        self.train_story = self.train_dict_story.keys()
        self.val_story = self.val_dict_story.keys()
        self.test_story = self.test_dict_story.keys()
        self.vocab_size = len(self.vocab)

        # ---- symbolic inputs ------------------------------------------
        self.q_var = T.tensor3(
            'q_var')  # image features; assumed (B, S, C) per the original note.
        self.answer_var = T.imatrix(
            'answer_var')  # integer target word ids; flattened to (n,) below.
        # Per-token loss weights; flattened to (n,) below, same as answer_var.
        self.answer_mask = T.matrix('answer_mask')
        self.answer_inp_var = T.tensor3(
            'answer_inp_var')  # answer input tokens; assumed (B*S, T, V).

        print "==> building input module"
        # The input module is a single affine map from cnn_dim to dim.
        logging.info('self.cnn_dim = %d', self.cnn_dim)
        self.W_inp_emb_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim,
                                                         self.cnn_dim))
        self.b_inp_emb_in = nn_utils.constant_param(value=0.0,
                                                    shape=(self.dim, ))

        # Replicate every story S times so that each of the B*S per-image
        # examples has the full image sequence of its story as facts:
        # (B, S, C) -> (B, 1, S, C) -> (B, S, S, C) -> (B*S, S, C).
        q_seq = self.q_var.dimshuffle(0, 'x', 1, 2)
        q_seq_rpt = T.repeat(q_seq, self.story_len, 1)
        q_seq_rhp = T.reshape(q_seq_rpt,
                              (q_seq_rpt.shape[0] * q_seq_rpt.shape[1],
                               q_seq_rpt.shape[2], q_seq_rpt.shape[3]))

        inp_var_shuffled = q_seq_rhp.dimshuffle(1, 2, 0)  # (S, C, B*S)

        def _dot(x, W, b):
            # Affine map of one time step: x is (C, B*S), result is (D, B*S).
            return T.dot(W, x) + b.dimshuffle(0, 'x')

        inp_c_hist, _ = theano.scan(
            fn=_dot,
            sequences=inp_var_shuffled,
            non_sequences=[self.W_inp_emb_in, self.b_inp_emb_in])
        #inp_c_hist,_ = theano.scan(fn = _dot, sequences=self.input_var, non_sequences = [self.W_inp_emb_in, self.b_inp_emb_in])

        self.inp_c = inp_c_hist  # (S, D, B*S): seq x emb x batch

        print "==> building question module"
        # The per-image question embedding further below reuses
        # W_inp_emb_in / b_inp_emb_in from the input module.
        q_var_shuffled = self.q_var.dimshuffle(
            1, 2, 0)  # (S, C, B): story_len x image_size x batch_size

        # Weights of the GRU that produces the "global glimpse" of a story
        # (reset / update / candidate-hidden gates).
        # NOTE(review): none of the W_inpf_* / b_inpf_* variables is listed in
        # self.params below, so they receive no adadelta update and no L2
        # penalty -- confirm whether that is intended.
        self.W_inpf_res_in = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim,
                                                          self.cnn_dim))
        self.W_inpf_res_hid = nn_utils.normal_param(std=0.1,
                                                    shape=(self.dim, self.dim))
        self.b_inpf_res = nn_utils.constant_param(value=0.0,
                                                  shape=(self.dim, ))

        self.W_inpf_upd_in = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim,
                                                          self.cnn_dim))
        self.W_inpf_upd_hid = nn_utils.normal_param(std=0.1,
                                                    shape=(self.dim, self.dim))
        self.b_inpf_upd = nn_utils.constant_param(value=0.0,
                                                  shape=(self.dim, ))

        self.W_inpf_hid_in = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim,
                                                          self.cnn_dim))
        self.W_inpf_hid_hid = nn_utils.normal_param(std=0.1,
                                                    shape=(self.dim, self.dim))
        self.b_inpf_hid = nn_utils.constant_param(value=0.0,
                                                  shape=(self.dim, ))
        # Only used as a (D, B) shape template for the initial hidden state.
        inp_dummy = theano.shared(
            np.zeros((self.dim, self.batch_size), dtype=floatX))

        # Run the GRU over the S images of each story.
        # Assumes input_gru_step_forward returns a (D, B) state per step, so
        # q_glb is (S, D, B) -- TODO confirm against the helper.
        q_glb, _ = theano.scan(fn=self.input_gru_step_forward,
                               sequences=q_var_shuffled,
                               outputs_info=[T.zeros_like(inp_dummy)])
        q_glb_shuffled = q_glb.dimshuffle(2, 0,
                                          1)  # batch_size * seq_len * dim
        q_glb_last = q_glb_shuffled[:, -1, :]  # batch_size * dim (last step)

        # Per-image ("individual") question embedding: every image of every
        # story becomes its own example.
        #q_var_shuffled = self.q_var.dimshuffle(1,0)
        q_single = T.reshape(
            self.q_var,
            (self.q_var.shape[0] * self.q_var.shape[1], self.q_var.shape[2]))
        q_single_shuffled = q_single.dimshuffle(1,
                                                0)  # (C, B*S)

        # Same affine map as the input module; q_hist is (D, B*S).
        q_hist = T.dot(self.W_inp_emb_in,
                       q_single_shuffled) + self.b_inp_emb_in.dimshuffle(
                           0, 'x')
        q_hist_shuffled = q_hist.dimshuffle(1, 0)  # (B*S, D)

        if self.batch_norm:
            logging.info("Using batch normalization.")
        # Wrap the embedding in a small Lasagne network so that optional
        # batch normalisation and dropout can be applied to it.
        q_net = layers.InputLayer(shape=(self.batch_size * self.story_len,
                                         self.dim),
                                  input_var=q_hist_shuffled)
        if self.batch_norm:
            q_net = layers.BatchNormLayer(incoming=q_net)
        if self.dropout > 0 and self.mode == 'train':
            q_net = layers.DropoutLayer(q_net, p=self.dropout)
        #last_mem = layers.get_output(q_net).dimshuffle((1, 0))
        # NOTE(review): get_output is called without deterministic=True and the
        # same expression feeds test_fn / pred_fn; assuming `layers` is
        # lasagne.layers, verify that this is the intended inference behaviour.
        self.q_q = layers.get_output(q_net).dimshuffle(1, 0)  # (D, B*S)

        print "==> creating parameters for memory module"
        # GRU weights used to update the episodic memory (reset / update /
        # candidate-hidden gates); every matrix is (D, D).
        self.W_mem_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        # W_b, W_1, W_2, b_1, b_2 are not read anywhere in this constructor;
        # presumably they parameterise the attention gate inside new_episode
        # -- TODO confirm. W_b is created but left out of self.params below.
        self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1,
                                         shape=(self.dim, 7 * self.dim + 0))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim, ))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1, ))
        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        # The memory starts as the question embedding and is refined
        # memory_hops times: each hop builds an episode from the previous
        # memory and folds it in with one GRU update.
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            #m = printing.Print('mem')(memory[iter-1])
            current_episode = self.new_episode(memory[iter - 1])
            #current_episode = self.new_episode(m)
            #current_episode = printing.Print('current_episode')(current_episode)
            memory.append(
                self.GRU_update(memory[iter - 1], current_episode,
                                self.W_mem_res_in, self.W_mem_res_hid,
                                self.b_mem_res, self.W_mem_upd_in,
                                self.W_mem_upd_hid, self.b_mem_upd,
                                self.W_mem_hid_in, self.W_mem_hid_hid,
                                self.b_mem_hid))

        # Final memory, transposed for Lasagne; assumed (B*S, D), matching the
        # InputLayer shape declared right below.
        last_mem_raw = memory[-1].dimshuffle((1, 0))

        net = layers.InputLayer(shape=(self.batch_size * self.story_len,
                                       self.dim),
                                input_var=last_mem_raw)

        if self.batch_norm:
            net = layers.BatchNormLayer(incoming=net)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net).dimshuffle((1, 0))  # back to (D, B*S)

        print "==> building answer module"

        # Move time to the front for scan: assumed (B*S, T, V) -> (T, V, B*S).
        answer_inp_var_shuffled = self.answer_inp_var.dimshuffle(1, 2, 0)
        # W_mem_emb projects the concatenated 3*D context (question, memory,
        # global glimpse) down to D; W_inp_emb embeds V-sized word vectors
        # (presumably one-hot -- TODO confirm) into D.
        self.W_mem_emb = nn_utils.normal_param(std=0.1,
                                               shape=(self.dim, self.dim * 3))
        self.W_inp_emb = nn_utils.normal_param(std=0.1,
                                               shape=(self.dim,
                                                      self.vocab_size + 1))

        def _dot2(x, W):
            # Linear embedding of one answer time step (no bias).
            return T.dot(W, x)

        answer_inp_var_shuffled_emb, _ = theano.scan(
            fn=_dot2,
            sequences=answer_inp_var_shuffled,
            non_sequences=self.W_inp_emb)  # seq x dim x batch

        # Broadcast the per-story global glimpse to every image of the story:
        # (B, D) -> (B, 1, D) -> (B, S, D) -> (B*S, D).
        q_glb_dim = q_glb_last.dimshuffle(0, 'x', 1)  # batch_size * 1 * dim
        q_glb_repmat = T.repeat(q_glb_dim, self.story_len,
                                1)  # batch_size * len * dim
        q_glb_rhp = T.reshape(q_glb_repmat,
                              (q_glb_repmat.shape[0] * q_glb_repmat.shape[1],
                               q_glb_repmat.shape[2]))
        # Stack question, final memory and global glimpse along the feature
        # axis: (3*D, B*S).
        init_ans = T.concatenate(
            [self.q_q, last_mem,
             q_glb_rhp.dimshuffle(1, 0)], axis=0)

        mem_ans = T.dot(self.W_mem_emb, init_ans)  # dim x (batch_size * story_len)
        mem_ans_dim = mem_ans.dimshuffle('x', 0, 1)  # (1, D, B*S)
        # The projected context is fed to the answer GRU as an extra first
        # time step, ahead of the embedded answer tokens: (T + 1, D, B*S).
        answer_inp = T.concatenate([mem_ans_dim, answer_inp_var_shuffled_emb],
                                   axis=0)

        # All-zero initial hidden state of the answer GRU, (D, B*S).
        dummy = theano.shared(
            np.zeros((self.dim, self.batch_size * self.story_len),
                     dtype=floatX))

        # Output projection from the hidden state to vocabulary scores.
        self.W_a = nn_utils.normal_param(std=0.1,
                                         shape=(self.vocab_size + 1, self.dim))

        # Answer GRU weights (reset / update / candidate-hidden gates).
        self.W_ans_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_ans_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_ans_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_ans_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_ans_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_ans_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        # Assumes answer_gru_step returns the new (D, B*S) hidden state, so
        # results is (T + 1, D, B*S) -- TODO confirm against the helper.
        results, _ = theano.scan(fn=self.answer_gru_step,
                                 sequences=answer_inp,
                                 outputs_info=[dummy])
        # Assume there is a start token
        #print results.shape.eval({self.input_var: np.random.rand(10,4,4096).astype('float32'),
        #    self.q_var: np.random.rand(10, 4096).astype('float32'),
        #    self.answer_inp_var: np.random.rand(10, 18, 8001).astype('float32')}, on_unused_input='ignore')
        #results = results[1:-1,:,:] # get rid of the last token as well as the first one (image)
        #print results.shape.eval({self.input_var: np.random.rand(10,4,4096).astype('float32'),
        #    self.q_var: np.random.rand(10, 4096).astype('float32'),
        #    self.answer_inp_var: np.random.rand(10, 18, 8001).astype('float32')}, on_unused_input='ignore')

        # Project every hidden state to unnormalised vocabulary scores:
        # (T + 1, V, B*S).
        prob, _ = theano.scan(fn=lambda x, w: T.dot(w, x),
                              sequences=results,
                              non_sequences=self.W_a)
        # Step 0 is the context step and is dropped in both views. `preds`
        # keeps all remaining steps (used for beam search); `prob` also drops
        # the last step (used for the training loss).
        preds = prob[1:, :, :]
        prob = prob[1:-1, :, :]

        prob_shuffled = prob.dimshuffle(2, 0, 1)  # b * len * vocab
        preds_shuffled = preds.dimshuffle(2, 0, 1)

        logging.info("prob shape.")
        #print prob.shape.eval({self.input_var: np.random.rand(10,4,4096).astype('float32'),
        #    self.q_var: np.random.rand(10, 4096).astype('float32'),
        #    self.answer_inp_var: np.random.rand(10, 18, 8001).astype('float32')})

        # Flatten (example, time) into one axis so the softmax is applied to
        # one row per token.
        n = prob_shuffled.shape[0] * prob_shuffled.shape[1]
        n_preds = preds_shuffled.shape[0] * preds_shuffled.shape[1]

        prob_rhp = T.reshape(prob_shuffled, (n, prob_shuffled.shape[2]))
        preds_rhp = T.reshape(preds_shuffled,
                              (n_preds, preds_shuffled.shape[2]))

        prob_sm = nn_utils.softmax_(prob_rhp)
        preds_sm = nn_utils.softmax_(preds_rhp)
        self.prediction = prob_sm  # (n, V); this one is for the training.

        # This one is for the beam search: restored to (example, time, vocab).
        self.pred = T.reshape(
            preds_sm, (preds_shuffled.shape[0], preds_shuffled.shape[1],
                       preds_shuffled.shape[2]))

        # Flatten targets and mask to line up with the rows of prob_sm; both
        # inputs must therefore contain exactly n elements.
        mask = T.reshape(self.answer_mask, (n, ))
        lbl = T.reshape(self.answer_var, (n, ))

        # Variables that are optimised and L2-regularised.
        # NOTE(review): the W_inpf_* / b_inpf_* GRU weights and W_b are absent
        # from this list -- confirm whether that is intended.
        self.params = [
            self.W_inp_emb_in,
            self.b_inp_emb_in,
            self.W_mem_res_in,
            self.W_mem_res_hid,
            self.b_mem_res,
            self.W_mem_upd_in,
            self.W_mem_upd_hid,
            self.b_mem_upd,
            self.W_mem_hid_in,
            self.W_mem_hid_hid,
            self.b_mem_hid,  # self.W_b is commented out here, so it gets no update
            self.W_1,
            self.W_2,
            self.b_1,
            self.b_2,
            self.W_a
        ]

        self.params = self.params + [
            self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
            self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
            self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid,
            self.W_mem_emb, self.W_inp_emb
        ]

        print "==> building loss layer and computing updates"
        # Per-token cross-entropy, averaged over the tokens selected by the
        # mask (the sum of the mask is the divisor).
        loss_vec = T.nnet.categorical_crossentropy(prob_sm, lbl)
        self.loss_ce = (mask * loss_vec).sum() / mask.sum()

        #self.loss_ce = T.nnet.categorical_crossentropy(results_rhp, lbl)

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        # The update graph is built regardless of mode, but it is only
        # attached to train_fn below.
        updates = lasagne.updates.adadelta(self.loss, self.params)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.001)

        if self.mode == 'train':
            print "==> compiling train_fn"
            # One optimisation step; returns [token probabilities, loss].
            self.train_fn = theano.function(
                inputs=[
                    self.q_var, self.answer_var, self.answer_mask,
                    self.answer_inp_var
                ],
                outputs=[self.prediction, self.loss],
                updates=updates)

        print "==> compiling test_fn"
        # Same inputs and outputs as train_fn, without parameter updates.
        self.test_fn = theano.function(inputs=[
            self.q_var, self.answer_var, self.answer_mask, self.answer_inp_var
        ],
                                       outputs=[self.prediction, self.loss])

        print "==> compiling pred_fn"
        # Inference only: needs no targets or mask; returns the per-step word
        # distributions self.pred.
        self.pred_fn = theano.function(
            inputs=[self.q_var, self.answer_inp_var], outputs=[self.pred])
    def __init__(self, data_dir, word2vec, word_vector_size, truncate_gradient,
                 learning_rate, dim, cnn_dim, cnn_dim_fc, story_len, patches,
                 mode, answer_module, memory_hops, batch_size, l2,
                 normalize_attention, batch_norm, dropout, **kwargs):

        print "==> not used params in DMN class:", kwargs.keys()

        self.data_dir = data_dir
        self.learning_rate = learning_rate

        self.truncate_gradient = truncate_gradient
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.cnn_dim = cnn_dim
        self.cnn_dim_fc = cnn_dim_fc
        self.story_len = story_len
        self.mode = mode
        self.patches = patches
        self.answer_module = answer_module
        self.memory_hops = memory_hops
        self.batch_size = batch_size
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.batch_norm = batch_norm
        self.dropout = dropout

        #self.vocab, self.ivocab = self._load_vocab(self.data_dir)
        self.vocab, self.ivocab = self._ext_vocab_from_word2vec()

        self.train_story = None
        self.test_story = None
        self.train_dict_story, self.train_lmdb_env_fc, self.train_lmdb_env_conv = self._process_input_sind(
            self.data_dir, 'train')
        self.test_dict_story, self.test_lmdb_env_fc, self.test_lmdb_env_conv = self._process_input_sind(
            self.data_dir, 'val')

        self.train_story = self.train_dict_story.keys()
        self.test_story = self.test_dict_story.keys()
        self.vocab_size = len(self.vocab)

        # Since this is pretty expensive, we will pass a story each time.
        # We assume that the input has been processed such that the sequences of patches
        # are snake like path.

        self.input_var = T.tensor4(
            'input_var')  # (batch_size, seq_len, patches, cnn_dim)
        self.q_var = T.matrix('q_var')  # Now, it's a batch * image_sieze.
        self.answer_var = T.imatrix(
            'answer_var')  # answer of example in minibatch
        self.answer_mask = T.matrix('answer_mask')
        self.answer_inp_var = T.tensor3(
            'answer_inp_var')  # answer of example in minibatch

        print "==> building input module"
        self.W_inp_emb_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim,
                                                         self.cnn_dim))
        #self.b_inp_emb_in = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        # First, we embed the visual features before sending it to the bi-GRUs.

        inp_rhp = T.reshape(
            self.input_var,
            (self.batch_size * self.story_len * self.patches, self.cnn_dim))
        inp_rhp_dimshuffled = inp_rhp.dimshuffle(1, 0)
        inp_rhp_emb = T.dot(self.W_inp_emb_in, inp_rhp_dimshuffled)
        inp_rhp_emb_dimshuffled = inp_rhp_emb.dimshuffle(1, 0)
        inp_emb_raw = T.reshape(
            inp_rhp_emb_dimshuffled,
            (self.batch_size, self.story_len, self.patches, self.cnn_dim))
        inp_emb = T.tanh(
            inp_emb_raw
        )  # Just follow the paper DMN for visual and textual QA.

        # Now, we use a bi-directional GRU to produce the input.
        # Forward GRU.
        self.inp_dim = self.dim / 2  # since we have forward and backward
        self.W_inpf_res_in = nn_utils.normal_param(std=0.1,
                                                   shape=(self.inp_dim,
                                                          self.cnn_dim))
        self.W_inpf_res_hid = nn_utils.normal_param(std=0.1,
                                                    shape=(self.inp_dim,
                                                           self.inp_dim))
        self.b_inpf_res = nn_utils.constant_param(value=0.0,
                                                  shape=(self.inp_dim, ))

        self.W_inpf_upd_in = nn_utils.normal_param(std=0.1,
                                                   shape=(self.inp_dim,
                                                          self.cnn_dim))
        self.W_inpf_upd_hid = nn_utils.normal_param(std=0.1,
                                                    shape=(self.inp_dim,
                                                           self.inp_dim))
        self.b_inpf_upd = nn_utils.constant_param(value=0.0,
                                                  shape=(self.inp_dim, ))

        self.W_inpf_hid_in = nn_utils.normal_param(std=0.1,
                                                   shape=(self.inp_dim,
                                                          self.cnn_dim))
        self.W_inpf_hid_hid = nn_utils.normal_param(std=0.1,
                                                    shape=(self.inp_dim,
                                                           self.inp_dim))
        self.b_inpf_hid = nn_utils.constant_param(value=0.0,
                                                  shape=(self.inp_dim, ))
        # Backward GRU.
        self.W_inpb_res_in = nn_utils.normal_param(std=0.1,
                                                   shape=(self.inp_dim,
                                                          self.cnn_dim))
        self.W_inpb_res_hid = nn_utils.normal_param(std=0.1,
                                                    shape=(self.inp_dim,
                                                           self.inp_dim))
        self.b_inpb_res = nn_utils.constant_param(value=0.0,
                                                  shape=(self.inp_dim, ))

        self.W_inpb_upd_in = nn_utils.normal_param(std=0.1,
                                                   shape=(self.inp_dim,
                                                          self.cnn_dim))
        self.W_inpb_upd_hid = nn_utils.normal_param(std=0.1,
                                                    shape=(self.inp_dim,
                                                           self.inp_dim))
        self.b_inpb_upd = nn_utils.constant_param(value=0.0,
                                                  shape=(self.inp_dim, ))

        self.W_inpb_hid_in = nn_utils.normal_param(std=0.1,
                                                   shape=(self.inp_dim,
                                                          self.cnn_dim))
        self.W_inpb_hid_hid = nn_utils.normal_param(std=0.1,
                                                    shape=(self.inp_dim,
                                                           self.inp_dim))
        self.b_inpb_hid = nn_utils.constant_param(value=0.0,
                                                  shape=(self.inp_dim, ))

        # Now, we use the GRU to build the inputs.
        # Two-level of nested scan is unnecessary. It will become too complicated. Just use this one.
        inp_dummy = theano.shared(
            np.zeros((self.inp_dim, self.story_len), dtype=floatX))
        for i in range(self.batch_size):
            if i == 0:
                inp_1st_f, _ = theano.scan(
                    fn=self.input_gru_step_forward,
                    sequences=inp_emb[i, :].dimshuffle(1, 2, 0),
                    outputs_info=T.zeros_like(inp_dummy),
                    truncate_gradient=self.truncate_gradient)

                inp_1st_b, _ = theano.scan(
                    fn=self.input_gru_step_backward,
                    sequences=inp_emb[i, :, ::-1, :].dimshuffle(1, 2, 0),
                    outputs_info=T.zeros_like(inp_dummy),
                    truncate_gradient=self.truncate_gradient)
                # Now, combine them.
                inp_1st = T.concatenate([
                    inp_1st_f.dimshuffle(2, 0, 1),
                    inp_1st_b.dimshuffle(2, 0, 1)
                ],
                                        axis=-1)
                self.inp_c = inp_1st.dimshuffle('x', 0, 1, 2)
            else:
                inp_f, _ = theano.scan(
                    fn=self.input_gru_step_forward,
                    sequences=inp_emb[i, :].dimshuffle(1, 2, 0),
                    outputs_info=T.zeros_like(inp_dummy),
                    truncate_gradient=self.truncate_gradient)

                inp_b, _ = theano.scan(
                    fn=self.input_gru_step_backward,
                    sequences=inp_emb[i, :, ::-1, :].dimshuffle(1, 2, 0),
                    outputs_info=T.zeros_like(inp_dummy),
                    truncate_gradient=self.truncate_gradient)
                # Now, combine them.
                inp_fb = T.concatenate(
                    [inp_f.dimshuffle(2, 0, 1),
                     inp_b.dimshuffle(2, 0, 1)],
                    axis=-1)
                self.inp_c = T.concatenate(
                    [self.inp_c, inp_fb.dimshuffle('x', 0, 1, 2)], axis=0)
        # Done, now self.inp_c should be batch_size x story_len x patches x cnn_dim
        # Eventually, we can flatten them.
        # Now, the input dimension is 1024 because we have forward and backward.
        inp_c_t = T.reshape(
            self.inp_c,
            (self.batch_size, self.story_len * self.patches, self.dim))
        inp_c_t_dimshuffled = inp_c_t.dimshuffle(0, 'x', 1, 2)
        inp_batch = T.repeat(inp_c_t_dimshuffled, self.story_len, axis=1)
        # Now, its ready for all the 5 images in the same story.
        # 50 * 980 * 512
        self.inp_batch = T.reshape(inp_batch,
                                   (inp_batch.shape[0] * inp_batch.shape[1],
                                    inp_batch.shape[2], inp_batch.shape[3]))
        self.inp_batch_dimshuffled = self.inp_batch.dimshuffle(
            1, 2, 0)  # 980 x 512 x 50

        # It's very simple now, the input module just need to map from cnn_dim to dim.
        logging.info('self.cnn_dim = %d', self.cnn_dim)

        print "==> building question module"
        # Now, share the parameter with the input module.
        self.W_inp_emb_q = nn_utils.normal_param(std=0.1,
                                                 shape=(self.dim,
                                                        self.cnn_dim_fc))
        self.b_inp_emb_q = nn_utils.normal_param(std=0.1, shape=(self.dim, ))
        q_var_shuffled = self.q_var.dimshuffle(1, 0)

        inp_q = T.dot(
            self.W_inp_emb_q, q_var_shuffled) + self.b_inp_emb_q.dimshuffle(
                0, 'x')  # 512 x 50
        self.q_q = T.tanh(
            inp_q
        )  # Since this is used to initialize the memory, we need to make it tanh.

        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        #self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_1 = nn_utils.normal_param(std=0.1,
                                         shape=(self.dim, 7 * self.dim + 0))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim, ))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1, ))

        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            #m = printing.Print('mem')(memory[iter-1])
            current_episode = self.new_episode(memory[iter - 1])
            #current_episode = self.new_episode(m)
            #current_episode = printing.Print('current_episode')(current_episode)
            memory.append(
                self.GRU_update(memory[iter - 1], current_episode,
                                self.W_mem_res_in, self.W_mem_res_hid,
                                self.b_mem_res, self.W_mem_upd_in,
                                self.W_mem_upd_hid, self.b_mem_upd,
                                self.W_mem_hid_in, self.W_mem_hid_hid,
                                self.b_mem_hid))

        last_mem_raw = memory[-1].dimshuffle((1, 0))

        net = layers.InputLayer(shape=(self.batch_size * self.story_len,
                                       self.dim),
                                input_var=last_mem_raw)

        if self.batch_norm:
            net = layers.BatchNormLayer(incoming=net)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net).dimshuffle((1, 0))

        logging.info('last_mem size')
        print last_mem.shape.eval({
            self.input_var:
            np.random.rand(10, 5, 196, 512).astype('float32'),
            self.q_var:
            np.random.rand(50, 4096).astype('float32')
        })

        print "==> building answer module"

        answer_inp_var_shuffled = self.answer_inp_var.dimshuffle(1, 2, 0)
        # Sounds good. Now, we need to map last_mem to a new space.
        self.W_mem_emb = nn_utils.normal_param(std=0.1,
                                               shape=(self.dim, self.dim * 2))
        self.W_inp_emb = nn_utils.normal_param(std=0.1,
                                               shape=(self.dim,
                                                      self.word_vector_size))

        def _dot2(x, W):
            return T.dot(W, x)

        answer_inp_var_shuffled_emb, _ = theano.scan(
            fn=_dot2,
            sequences=answer_inp_var_shuffled,
            non_sequences=self.W_inp_emb)  # seq x dim x batch

        # Now, we also need to embed the image and use it to do the memory.
        #q_q_shuffled = self.q_q.dimshuffle(1,0) # dim * batch.
        init_ans = T.concatenate([self.q_q, last_mem], axis=0)

        mem_ans = T.dot(self.W_mem_emb, init_ans)  # dim x batchsize.

        mem_ans_dim = mem_ans.dimshuffle('x', 0, 1)

        answer_inp = T.concatenate([mem_ans_dim, answer_inp_var_shuffled_emb],
                                   axis=0)

        # Now, we have both embedding. We can let them go to the rnn.

        # We also need to map the input layer as well.

        dummy = theano.shared(
            np.zeros((self.dim, self.batch_size * self.story_len),
                     dtype=floatX))

        self.W_a = nn_utils.normal_param(std=0.1,
                                         shape=(self.vocab_size + 1, self.dim))

        self.W_ans_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_ans_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_ans_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_ans_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_ans_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_ans_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        logging.info('answer_inp size')

        #print answer_inp.shape.eval({self.input_var: np.random.rand(10,4,4096).astype('float32'),
        #    self.answer_inp_var: np.random.rand(10, 18, 8001).astype('float32'),
        #    self.q_var: np.random.rand(10, 4096).astype('float32')})

        #last_mem = printing.Print('prob_sm')(last_mem)
        results, _ = theano.scan(fn=self.answer_gru_step,
                                 sequences=answer_inp,
                                 outputs_info=[dummy])
        # Assume there is a start token
        #print results.shape.eval({self.input_var: np.random.rand(10,4,4096).astype('float32'),
        #    self.q_var: np.random.rand(10, 4096).astype('float32'),
        #    self.answer_inp_var: np.random.rand(10, 18, 8001).astype('float32')}, on_unused_input='ignore')
        results = results[
            1:
            -1, :, :]  # get rid of the last token as well as the first one (image)
        #print results.shape.eval({self.input_var: np.random.rand(10,4,4096).astype('float32'),
        #    self.q_var: np.random.rand(10, 4096).astype('float32'),
        #    self.answer_inp_var: np.random.rand(10, 18, 8001).astype('float32')}, on_unused_input='ignore')

        # Now, we need to transform it to the probabilities.

        prob, _ = theano.scan(fn=lambda x, w: T.dot(w, x),
                              sequences=results,
                              non_sequences=self.W_a)

        prob_shuffled = prob.dimshuffle(2, 0, 1)  # b * len * vocab

        logging.info("prob shape.")
        #print prob.shape.eval({self.input_var: np.random.rand(10,4,4096).astype('float32'),
        #    self.q_var: np.random.rand(10, 4096).astype('float32'),
        #    self.answer_inp_var: np.random.rand(10, 18, 8001).astype('float32')})

        n = prob_shuffled.shape[0] * prob_shuffled.shape[1]
        prob_rhp = T.reshape(prob_shuffled, (n, prob_shuffled.shape[2]))
        prob_sm = nn_utils.softmax_(prob_rhp)
        self.prediction = prob_sm

        mask = T.reshape(self.answer_mask, (n, ))
        lbl = T.reshape(self.answer_var, (n, ))

        self.params = [
            self.W_inp_emb_in,  #self.b_inp_emb_in, 
            self.W_inpf_res_in,
            self.W_inpf_res_hid,
            self.b_inpf_res,
            self.W_inpf_upd_in,
            self.W_inpf_upd_hid,
            self.b_inpf_upd,
            self.W_inpf_hid_in,
            self.W_inpf_hid_hid,
            self.b_inpf_hid,
            self.W_inpb_res_in,
            self.W_inpb_res_hid,
            self.b_inpb_res,
            self.W_inpb_upd_in,
            self.W_inpb_upd_hid,
            self.b_inpb_upd,
            self.W_inpb_hid_in,
            self.W_inpb_hid_hid,
            self.b_inpb_hid,
            self.W_inp_emb_q,
            self.b_inp_emb_q,
            self.W_mem_res_in,
            self.W_mem_res_hid,
            self.b_mem_res,
            self.W_mem_upd_in,
            self.W_mem_upd_hid,
            self.b_mem_upd,
            self.W_mem_hid_in,
            self.W_mem_hid_hid,
            self.b_mem_hid,  #self.W_b
            self.W_1,
            self.W_2,
            self.b_1,
            self.b_2,
            self.W_a,
            self.W_mem_emb,
            self.W_inp_emb,
            self.W_ans_res_in,
            self.W_ans_res_hid,
            self.b_ans_res,
            self.W_ans_upd_in,
            self.W_ans_upd_hid,
            self.b_ans_upd,
            self.W_ans_hid_in,
            self.W_ans_hid_hid,
            self.b_ans_hid,
        ]

        print "==> building loss layer and computing updates"
        loss_vec = T.nnet.categorical_crossentropy(prob_sm, lbl)
        self.loss_ce = (mask * loss_vec).sum() / mask.sum()

        #self.loss_ce = T.nnet.categorical_crossentropy(results_rhp, lbl)

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        #updates = lasagne.updates.adadelta(self.loss, self.params)
        #updates = lasagne.updates.adam(self.loss, self.params, learning_rate = self.learning_rate)
        updates = lasagne.updates.rmsprop(self.loss,
                                          self.params,
                                          learning_rate=self.learning_rate)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.001)

        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.answer_var,
                    self.answer_mask, self.answer_inp_var
                ],
                outputs=[self.prediction, self.loss],
                updates=updates)
            #profile = True)

        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[
            self.input_var, self.q_var, self.answer_var, self.answer_mask,
            self.answer_inp_var
        ],
                                       outputs=[self.prediction, self.loss])
    def __init__(self, data_dir, word2vec, truncate_gradient, learning_rate,
                 dim, cnn_dim, story_len, mode, answer_module, batch_size, l2,
                 dropout, **kwargs):
        # Build the whole model: load the vocabulary and the story data,
        # construct the Theano graph (a forward GRU over q_var whose final
        # state seeds a GRU over the embedded answer inputs), define the
        # masked cross-entropy loss, and compile train_fn / test_fn / pred_fn.
        #
        # Arguments, as they are used inside this constructor:
        #   data_dir          -- passed to _load_vocab and _process_input_sind.
        #   word2vec          -- stored on self; not otherwise used here.
        #   truncate_gradient -- stored on self; not passed to any of the
        #                        theano.scan calls built here.
        #                        NOTE(review): confirm this is intended.
        #   learning_rate     -- forwarded to lasagne.updates.adadelta.
        #   dim               -- size of the recurrent state (the initial
        #                        state is dim x batch_size) and row count of
        #                        every weight matrix except W_a.
        #   cnn_dim           -- column count of the forward GRU's input
        #                        weights; presumably the per-image feature
        #                        size of q_var -- TODO confirm.
        #   story_len         -- stored on self; not otherwise used here.
        #   mode              -- train_fn is compiled only when this equals
        #                        'train'.
        #   answer_module     -- stored on self; not otherwise used here.
        #   batch_size        -- fixes the batch axis of the initial GRU state.
        #   l2                -- weight of the L2 penalty; a value <= 0
        #                        disables the penalty.
        #   dropout           -- stored on self; not otherwise used here.
        #   **kwargs          -- ignored; only their names are printed.

        print "==> not used params in DMN class:", kwargs.keys()

        self.data_dir = data_dir
        self.learning_rate = learning_rate
        # Fixed at 300; unlike the other settings it is not an argument.
        self.word_vector_size = 300

        self.truncate_gradient = truncate_gradient
        self.word2vec = word2vec
        self.dim = dim
        self.cnn_dim = cnn_dim
        self.story_len = story_len
        self.mode = mode
        self.answer_module = answer_module
        self.batch_size = batch_size
        self.l2 = l2
        self.dropout = dropout

        self.vocab, self.ivocab = self._load_vocab(self.data_dir)

        # Placeholders only; both are overwritten a few lines below once the
        # story dictionaries have been loaded.
        self.train_story = None
        self.test_story = None
        self.train_dict_story, self.train_lmdb_env_fc = self._process_input_sind(
            self.data_dir, 'train')
        # NOTE(review): the "test" attributes are filled from the 'val' split;
        # confirm that this is the intended evaluation set.
        self.test_dict_story, self.test_lmdb_env_fc = self._process_input_sind(
            self.data_dir, 'val')

        # The story lists are the keys of the loaded story dictionaries.
        self.train_story = self.train_dict_story.keys()
        self.test_story = self.test_dict_story.keys()
        self.vocab_size = len(self.vocab)

        # Symbolic inputs of the compiled functions.
        self.q_var = T.tensor3(
            'q_var')  # batch x story_len x image feature size
        self.answer_var = T.imatrix(
            'answer_var')  # integer labels; flattened to a vector for the loss
        # Flattened alongside answer_var and used to weight the per-token loss.
        self.answer_mask = T.matrix('answer_mask')
        self.answer_inp_var = T.tensor3(
            'answer_inp_var')  # batch x seq x (vocab_size + 1), see W_inp_emb

        q_shuffled = self.q_var.dimshuffle(
            1, 2, 0)  # now: story_len x image_size x batch_size

        print "==> building input module"

        # Parameters of the forward GRU that reads q_var one story position at
        # a time. The res / upd / hid groups are presumably the reset gate,
        # update gate and candidate state -- the step function
        # (input_gru_step_forward) is defined outside this constructor.
        self.W_inpf_res_in = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim,
                                                          self.cnn_dim))
        self.W_inpf_res_hid = nn_utils.normal_param(std=0.1,
                                                    shape=(self.dim, self.dim))
        self.b_inpf_res = nn_utils.constant_param(value=0.0,
                                                  shape=(self.dim, ))

        self.W_inpf_upd_in = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim,
                                                          self.cnn_dim))
        self.W_inpf_upd_hid = nn_utils.normal_param(std=0.1,
                                                    shape=(self.dim, self.dim))
        self.b_inpf_upd = nn_utils.constant_param(value=0.0,
                                                  shape=(self.dim, ))

        self.W_inpf_hid_in = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim,
                                                          self.cnn_dim))
        self.W_inpf_hid_hid = nn_utils.normal_param(std=0.1,
                                                    shape=(self.dim, self.dim))
        self.b_inpf_hid = nn_utils.constant_param(value=0.0,
                                                  shape=(self.dim, ))

        # Run the GRU over the story with a single scan; a two-level nested
        # scan was judged unnecessarily complicated.
        # inp_dummy only provides the shape and dtype of the all-zero initial
        # state (dim x batch_size); it is not part of self.params.
        inp_dummy = theano.shared(
            np.zeros((self.dim, self.batch_size), dtype=floatX))

        q_inp, _ = theano.scan(fn=self.input_gru_step_forward,
                               sequences=q_shuffled,
                               outputs_info=[T.zeros_like(inp_dummy)])
        # Move the batch axis to the front, take the state of the final story
        # position, then transpose it back to dim x batch_size.
        q_inp_shuffled = q_inp.dimshuffle(2, 0, 1)
        q_inp_last = q_inp_shuffled[:, -1, :].dimshuffle(1,
                                                         0)  # dim x batch_size

        # Affine map of the final GRU state followed by tanh.
        # NOTE(review): b_inp_emb_q is drawn with normal_param, whereas every
        # other bias in this constructor uses constant_param(0.0) -- confirm
        # that this is deliberate.
        self.W_inp_emb_q = nn_utils.normal_param(std=0.1,
                                                 shape=(self.dim, self.dim))
        self.b_inp_emb_q = nn_utils.normal_param(std=0.1, shape=(self.dim, ))

        inp_q = T.dot(
            self.W_inp_emb_q, q_inp_last) + self.b_inp_emb_q.dimshuffle(
                0, 'x')  # dim x batch_size (e.g. 512 x 50)
        self.q_q = T.tanh(
            inp_q
        )  # tanh-squashed: used below as the initial state of the answer GRU.

        print "==> building answer module"

        # Reorder to seq x (vocab_size + 1) x batch so scan walks over tokens.
        answer_inp_var_shuffled = self.answer_inp_var.dimshuffle(1, 2, 0)
        # Linear embedding of each answer-input vector into the dim-sized space.
        self.W_inp_emb = nn_utils.normal_param(std=0.1,
                                               shape=(self.dim,
                                                      self.vocab_size + 1))

        # Scan step: left-multiply one time step by the embedding matrix.
        def _dot2(x, W):
            return T.dot(W, x)

        answer_inp_var_shuffled_emb, _ = theano.scan(
            fn=_dot2,
            sequences=answer_inp_var_shuffled,
            non_sequences=self.W_inp_emb)  # seq x dim x batch

        # Output projection from the recurrent state to vocab_size + 1 scores.
        self.W_a = nn_utils.normal_param(std=0.1,
                                         shape=(self.vocab_size + 1, self.dim))

        # Parameters of the answer GRU (same res / upd / hid grouping as the
        # input GRU; the step function answer_gru_step is defined elsewhere).
        self.W_ans_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_ans_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_ans_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_ans_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_ans_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_ans_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        # Run the answer GRU over the embedded inputs, starting from self.q_q.
        results, _ = theano.scan(fn=self.answer_gru_step,
                                 sequences=answer_inp_var_shuffled_emb,
                                 outputs_info=[self.q_q])
        # The input sequence is assumed to begin with a start token.
        # All time steps are kept here; the last one is dropped after the
        # output projection (see `prob` below) instead of on `results`.

        # Project every recurrent state to vocabulary scores.

        prob_f, _ = theano.scan(fn=lambda x, w: T.dot(w, x),
                                sequences=results,
                                non_sequences=self.W_a)
        # prob_f keeps every time step (used for self.prediction); prob drops
        # the final step and is the one the loss is computed on.
        prob = prob_f[:-1, :, :]

        prob_shuffled = prob.dimshuffle(2, 0, 1)  # batch x len x vocab
        prob_f_shuffled = prob_f.dimshuffle(2, 0, 1)

        # Only the literal message is logged; no shape is evaluated here.
        logging.info("prob shape.")

        # Flatten batch and time into one axis so that each row is one token,
        # then normalise (softmax_ is presumably a row-wise softmax -- it is
        # defined in nn_utils, outside this file's visible part).
        n = prob_shuffled.shape[0] * prob_shuffled.shape[1]
        prob_rhp = T.reshape(prob_shuffled, (n, prob_shuffled.shape[2]))
        prob_sm = nn_utils.softmax_(prob_rhp)
        # Flattened distribution over the truncated steps; returned by
        # train_fn and test_fn.
        self.pred = prob_sm

        # Same normalisation for the untruncated scores, reshaped back to
        # batch x len x vocab; returned by pred_fn.
        n_f = prob_f_shuffled.shape[0] * prob_f_shuffled.shape[1]
        prob_f_rhp = T.reshape(prob_f_shuffled,
                               (n_f, prob_f_shuffled.shape[2]))

        prob_f_sm = nn_utils.softmax_(prob_f_rhp)
        self.prediction = T.reshape(
            prob_f_sm, (prob_f_shuffled.shape[0], prob_f_shuffled.shape[1],
                        prob_f_shuffled.shape[2]))

        # Flatten mask and labels to line up with the rows of prob_sm; both
        # must therefore hold exactly n entries.
        mask = T.reshape(self.answer_mask, (n, ))
        lbl = T.reshape(self.answer_var, (n, ))

        # Everything that is regularised and updated by the optimiser.
        self.params = [
            self.W_inpf_res_in,
            self.W_inpf_res_hid,
            self.b_inpf_res,
            self.W_inpf_upd_in,
            self.W_inpf_upd_hid,
            self.b_inpf_upd,
            self.W_inpf_hid_in,
            self.W_inpf_hid_hid,
            self.b_inpf_hid,
            self.W_inp_emb_q,
            self.b_inp_emb_q,
            self.W_a,
            self.W_inp_emb,
            self.W_ans_res_in,
            self.W_ans_res_hid,
            self.b_ans_res,
            self.W_ans_upd_in,
            self.W_ans_upd_hid,
            self.b_ans_upd,
            self.W_ans_hid_in,
            self.W_ans_hid_hid,
            self.b_ans_hid,
        ]

        print "==> building loss layer and computing updates"
        # Mask-weighted mean of the per-token cross-entropy. The denominator
        # is mask.sum(), so a batch whose mask is all zeros divides by zero.
        loss_vec = T.nnet.categorical_crossentropy(prob_sm, lbl)
        self.loss_ce = (mask * loss_vec).sum() / mask.sum()

        # Optional L2 penalty over self.params, scaled by self.l2.
        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2

        # The update rules are built regardless of mode, although only
        # train_fn consumes them.
        updates = lasagne.updates.adadelta(self.loss,
                                           self.params,
                                           learning_rate=self.learning_rate)
        # Alternative kept for reference:
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.001)

        # train_fn applies the parameter updates; it exists only in train mode.
        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(inputs=[
                self.q_var, self.answer_var, self.answer_mask,
                self.answer_inp_var
            ],
                                            outputs=[self.pred, self.loss],
                                            updates=updates)

        # test_fn takes the same inputs and returns the same outputs as
        # train_fn but applies no updates.
        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[
            self.q_var, self.answer_var, self.answer_mask, self.answer_inp_var
        ],
                                       outputs=[self.pred, self.loss])

        # pred_fn needs no labels or mask and returns the per-step
        # distributions for all time steps (self.prediction).
        print "==> compiling pred_fn"
        self.pred_fn = theano.function(
            inputs=[self.q_var, self.answer_inp_var],
            outputs=[self.prediction])
Exemple #4
0
    def __init__(self, data_dir, word2vec, word_vector_size, dim, cnn_dim,
                 story_len, patches, cnn_dim_fc, truncate_gradient,
                 learning_rate, mode, answer_module, memory_hops, batch_size,
                 l2, normalize_attention, batch_norm, dropout, **kwargs):

        print "==> not used params in DMN class:", kwargs.keys()

        self.data_dir = data_dir
        self.truncate_gradient = truncate_gradient
        self.learning_rate = learning_rate

        self.trng = RandomStreams(1234)

        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.cnn_dim = cnn_dim
        self.cnn_dim_fc = cnn_dim_fc
        self.story_len = story_len
        self.patches = patches
        self.mode = mode
        self.answer_module = answer_module
        self.memory_hops = memory_hops
        self.batch_size = batch_size
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.batch_norm = batch_norm
        self.dropout = dropout

        self.vocab, self.ivocab = self._load_vocab(self.data_dir)

        self.train_story = None
        self.test_story = None
        self.train_dict_story, self.train_lmdb_env_fc, self.train_lmdb_env_conv = self._process_input_sind_lmdb(
            self.data_dir, 'train')
        self.test_dict_story, self.test_lmdb_env_fc, self.test_lmdb_env_conv = self._process_input_sind_lmdb(
            self.data_dir, 'val')

        self.train_story = self.train_dict_story.keys()
        self.test_story = self.test_dict_story.keys()
        self.vocab_size = len(self.vocab)

        # This is the local patch of each image.
        self.input_var = T.tensor4(
            'input_var')  # (batch_size, seq_len, patches, cnn_dim)
        self.q_var = T.tensor3(
            'q_var')  # Now, it's a batch * story_len * image_sieze.
        self.answer_var = T.ivector(
            'answer_var')  # answer of example in minibatch
        self.answer_mask = T.matrix('answer_mask')
        self.answer_idx = T.imatrix('answer_idx')  # batch x seq
        self.answer_inp_var = T.tensor3(
            'answer_inp_var')  # answer of example in minibatch

        print "==> building input module"
        # It's very simple now, the input module just need to map from cnn_dim to dim.
        logging.info('self.cnn_dim = %d', self.cnn_dim)
        logging.info('self.cnn_dim_fc = %d', self.cnn_dim_fc)
        logging.info('self.dim = %d', self.dim)
        self.W_q_emb_in = nn_utils.normal_param(std=0.1,
                                                shape=(self.dim,
                                                       self.cnn_dim_fc))
        self.b_q_emb_in = nn_utils.constant_param(value=0.0,
                                                  shape=(self.dim, ))

        q_var_shuffled = self.q_var.dimshuffle(1, 2, 0)  # seq x cnn x batch.

        def _dot(x, W, b):
            return T.tanh(T.dot(W, x) + b.dimshuffle(0, 'x'))

        q_var_shuffled_emb, _ = theano.scan(
            fn=_dot,
            sequences=q_var_shuffled,
            non_sequences=[self.W_q_emb_in, self.b_q_emb_in])
        #print 'q_var_shuffled_emb', q_var_shuffled_emb.shape.eval({self.q_var:np.random.rand(2,5,4096).astype('float32')})
        q_var_emb = q_var_shuffled_emb.dimshuffle(2, 0,
                                                  1)  # batch x seq x emb_size
        q_var_emb_ext = q_var_emb.dimshuffle(0, 'x', 1, 2)
        q_var_emb_ext = T.repeat(q_var_emb_ext, q_var_emb.shape[1],
                                 1)  # batch x seq x seq x emb_size
        q_var_emb_rhp = T.reshape(
            q_var_emb,
            (q_var_emb.shape[0] * q_var_emb.shape[1], q_var_emb.shape[2]))
        q_var_emb_ext_rhp = T.reshape(
            q_var_emb_ext, (q_var_emb_ext.shape[0] * q_var_emb_ext.shape[1],
                            q_var_emb_ext.shape[2], q_var_emb_ext.shape[3]))
        q_var_emb_ext_rhp = q_var_emb_ext_rhp.dimshuffle(0, 2, 1)
        q_idx = T.arange(self.story_len).dimshuffle('x', 0)
        q_idx = T.repeat(q_idx, self.batch_size, axis=0)
        q_idx = T.reshape(q_idx, (q_idx.shape[0] * q_idx.shape[1], ))

        self.W_inp_emb_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim,
                                                         self.cnn_dim))
        self.b_inp_emb_in = nn_utils.constant_param(value=0.0,
                                                    shape=(self.dim, ))

        inp_rhp = T.reshape(
            self.input_var,
            (self.batch_size * self.story_len * self.patches, self.cnn_dim))
        inp_rhp_dimshuffled = inp_rhp.dimshuffle(1, 0)
        inp_rhp_emb = T.dot(
            self.W_inp_emb_in,
            inp_rhp_dimshuffled) + self.b_inp_emb_in.dimshuffle(0, 'x')
        inp_rhp_emb_dimshuffled = inp_rhp_emb.dimshuffle(1, 0)
        inp_emb_raw = T.reshape(
            inp_rhp_emb_dimshuffled,
            (self.batch_size * self.story_len, self.patches, self.cnn_dim))
        inp_emb = T.tanh(
            inp_emb_raw
        )  # Just follow the paper DMN for visual and textual QA.

        self.inp_c = inp_emb.dimshuffle(1, 2, 0)

        logging.info('building question module')
        self.W_qf_res_in = nn_utils.normal_param(std=0.1,
                                                 shape=(self.dim, self.dim))
        self.W_qf_res_hid = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.b_qf_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_qf_upd_in = nn_utils.normal_param(std=0.1,
                                                 shape=(self.dim, self.dim))
        self.W_qf_upd_hid = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.b_qf_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_qf_hid_in = nn_utils.normal_param(std=0.1,
                                                 shape=(self.dim, self.dim))
        self.W_qf_hid_hid = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.b_qf_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        inp_dummy = theano.shared(
            np.zeros((self.dim, self.batch_size), dtype=floatX))

        q_var_shuffled_emb_reversed = q_var_shuffled_emb[::
                                                         -1, :, :]  # seq x emb_size x batch
        q_glb, _ = theano.scan(fn=self.q_gru_step_forward,
                               sequences=q_var_shuffled_emb_reversed,
                               outputs_info=[T.zeros_like(inp_dummy)])
        q_glb_shuffled = q_glb.dimshuffle(2, 0,
                                          1)  # batch_size * seq_len * dim
        q_glb_last = q_glb_shuffled[:, -1, :]  # batch_size * dim

        q_net = layers.InputLayer(shape=(self.batch_size * self.story_len,
                                         self.dim),
                                  input_var=q_var_emb_rhp)
        if self.batch_norm:
            q_net = layers.BatchNormLayer(incoming=q_net)
        if self.dropout > 0 and self.mode == 'train':
            q_net = layers.DropoutLayer(q_net, p=self.dropout)
        self.q_q = layers.get_output(q_net).dimshuffle(1, 0)

        #print "==> creating parameters for memory module"
        logging.info('creating parameters for memory module')
        self.W_mem_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_mem_update1 = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim,
                                                          self.dim * 2))
        self.b_mem_upd1 = nn_utils.constant_param(value=0.0,
                                                  shape=(self.dim, ))
        self.W_mem_update2 = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim,
                                                          self.dim * 2))
        self.b_mem_upd2 = nn_utils.constant_param(value=0.0,
                                                  shape=(self.dim, ))
        self.W_mem_update3 = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim,
                                                          self.dim * 2))
        self.b_mem_upd3 = nn_utils.constant_param(value=0.0,
                                                  shape=(self.dim, ))

        self.W_mem_update = [
            self.W_mem_update1, self.W_mem_update2, self.W_mem_update3
        ]
        self.b_mem_update = [self.b_mem_upd1, self.b_mem_upd2, self.b_mem_upd3]

        self.W_1 = nn_utils.normal_param(std=0.1,
                                         shape=(self.dim, 7 * self.dim + 0))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim, ))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1, ))
        logging.info(
            '==> building episodic memory module (fixed number of steps: %d)',
            self.memory_hops)
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            #m = printing.Print('mem')(memory[iter-1])
            current_episode = self.new_episode(memory[iter - 1])
            # Replace GRU with ReLU activation + MLP.
            c = T.concatenate([memory[iter - 1], current_episode], axis=0)
            cur_mem = T.dot(self.W_mem_update[iter - 1],
                            c) + self.b_mem_update[iter - 1].dimshuffle(
                                0, 'x')
            memory.append(T.nnet.relu(cur_mem))

        last_mem_raw = memory[-1].dimshuffle((1, 0))

        net = layers.InputLayer(shape=(self.batch_size * self.story_len,
                                       self.dim),
                                input_var=last_mem_raw)

        if self.batch_norm:
            net = layers.BatchNormLayer(incoming=net)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net).dimshuffle((1, 0))

        print "==> building answer module"

        answer_inp_var_shuffled = self.answer_inp_var.dimshuffle(1, 2, 0)
        # Sounds good. Now, we need to map last_mem to a new space.
        self.W_mem_emb = nn_utils.normal_param(std=0.1,
                                               shape=(self.dim, self.dim * 2))
        self.b_mem_emb = nn_utils.constant_param(value=0.0, shape=(self.dim, ))
        self.W_inp_emb = nn_utils.normal_param(std=0.1,
                                               shape=(self.dim,
                                                      self.vocab_size + 1))
        self.b_inp_emb = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        def _dot2(x, W, b):
            #return  T.tanh(T.dot(W, x) + b.dimshuffle(0,'x'))
            return T.dot(W, x) + b.dimshuffle(0, 'x')

        answer_inp_var_shuffled_emb, _ = theano.scan(
            fn=_dot2,
            sequences=answer_inp_var_shuffled,
            non_sequences=[self.W_inp_emb,
                           self.b_inp_emb])  # seq x dim x batch

        init_ans = T.concatenate([self.q_q, last_mem],
                                 axis=0)  # dim x (batch x self.story_len)

        mem_ans = T.dot(self.W_mem_emb, init_ans) + self.b_mem_emb.dimshuffle(
            0, 'x')  # dim x (batchsize x self.story_len)
        #mem_ans_dim = mem_ans.dimshuffle('x',0,1)
        mem_ans_rhp = T.reshape(mem_ans.dimshuffle(
            1, 0), (self.batch_size, self.story_len, mem_ans.shape[0]))
        mem_ans_dim = mem_ans_rhp.dimshuffle(1, 2, 0)
        answer_inp = answer_inp_var_shuffled_emb
        #answer_inp = T.concatenate([mem_ans_dim, answer_inp_var_shuffled_emb], axis = 0) #seq + 1 x dim x (batch-size x self.story+len)
        # Now, each answer got its input, our next step is to obtain the sequences.
        answer_inp_shu = answer_inp.dimshuffle(2, 0, 1)
        answer_inp_shu_rhp = T.reshape(answer_inp_shu, (self.batch_size, self.story_len, answer_inp_shu.shape[1],\
                answer_inp_shu.shape[2]))

        answer_inp = answer_inp_shu_rhp.dimshuffle(
            1, 2, 3, 0)  # story_len x seq + 1 x dim x batch_size

        self.W_a = nn_utils.normal_param(std=0.1,
                                         shape=(self.vocab_size + 1, self.dim))

        self.W_ans_res_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_ans_res_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_ans_upd_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_ans_upd_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_ans_hid_in = nn_utils.normal_param(std=0.1,
                                                  shape=(self.dim, self.dim))
        self.W_ans_hid_hid = nn_utils.normal_param(std=0.1,
                                                   shape=(self.dim, self.dim))
        self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        self.W_ans_map = nn_utils.normal_param(std=0.1,
                                               shape=(self.dim, self.dim * 2))
        self.b_ans_map = nn_utils.constant_param(value=0.0, shape=(self.dim, ))

        results = None
        r = None

        dummy = theano.shared(
            np.zeros((self.dim, self.batch_size), dtype=floatX))
        for i in range(self.story_len):
            answer_inp_i = answer_inp[i, :]  # seq + 1 x dim x batch_size
            mem_ans_dim_i = mem_ans_dim[i, :]  # dim x batch_size
            if i == 0:
                q_glb_inp = q_glb_last.dimshuffle('x', 1,
                                                  0)  #1 x dim x batch_size
                answer_inp_i = T.concatenate([q_glb_inp, answer_inp_i], axis=0)

                init_h = T.concatenate([dummy, mem_ans_dim_i], axis=0)
                init_h = T.dot(self.W_ans_map,
                               init_h) + self.b_ans_map.dimshuffle(0, 'x')
                init_h = T.tanh(init_h)
                r, _ = theano.scan(fn=self.answer_gru_step,
                                   sequences=answer_inp_i,
                                   truncate_gradient=self.truncate_gradient,
                                   outputs_info=[init_h])
                r = r[1:, :]  # get rid of the first glob one.
                results = r.dimshuffle('x', 0, 1, 2)
            else:
                prev_h = r[self.answer_idx[:, i], :, T.arange(self.batch_size)]
                h_ = T.concatenate([prev_h.dimshuffle(1, 0), mem_ans_dim_i],
                                   axis=0)
                h_ = T.dot(self.W_ans_map, h_) + self.b_ans_map.dimshuffle(
                    0, 'x')
                h_ = T.tanh(h_)

                r, _ = theano.scan(fn=self.answer_gru_step,
                                   sequences=answer_inp_i,
                                   truncate_gradient=self.truncate_gradient,
                                   outputs_info=[h_])
                results = T.concatenate([results, r.dimshuffle('x', 0, 1, 2)])
        ## results: story_len x seq+1 x dim x batch_size
        results = results.dimshuffle(3, 0, 1, 2)
        results = T.reshape(results, (self.batch_size * self.story_len,
                                      results.shape[2], results.shape[3]))
        results = results.dimshuffle(1, 2, 0)  # seq_len x dim x (batch x seq)

        # Assume there is a start token
        #print 'results', results.shape.eval({self.input_var: np.random.rand(2,5,196,512).astype('float32'),
        #    self.q_var: np.random.rand(2,5, 4096).astype('float32'),
        #    self.answer_idx: np.asarray([[1,1,1,1,1],[2,2,2,2,2]]).astype('int32'),
        #    self.answer_inp_var: np.random.rand(5, 18, 8001).astype('float32')})

        #results = results[1:-1,:,:] # get rid of the last token as well as the first one (image)
        #print results.shape.eval({self.input_var: np.random.rand(3,4,4096).astype('float32'),
        #    self.q_var: np.random.rand(3, 4096).astype('float32'),
        #    self.answer_inp_var: np.random.rand(3, 18, 8001).astype('float32')}, on_unused_input='ignore')

        # Now, we need to transform it to the probabilities.

        prob, _ = theano.scan(fn=lambda x, w: T.dot(w, x),
                              sequences=results,
                              non_sequences=self.W_a)
        #print 'prob', prob.shape.eval({self.input_var: np.random.rand(2,5,196,512).astype('float32'),
        #    self.q_var: np.random.rand(2,5, 4096).astype('float32'),
        #    self.answer_idx: np.asarray([[1,1,1,1,1],[2,2,2,2,2]]).astype('int32'),
        #    self.answer_inp_var: np.random.rand(5, 18, 8001).astype('float32')})

        #preds = prob[1:,:,:]
        #prob = prob[1:-1,:,:]
        preds = prob
        prob = prob[:-1, :, :]

        prob_shuffled = prob.dimshuffle(2, 0, 1)  # b * len * vocab
        preds_shuffled = preds.dimshuffle(2, 0, 1)

        logging.info("prob shape.")
        #print prob.shape.eval({self.input_var: np.random.rand(3,4,4096).astype('float32'),
        #    self.q_var: np.random.rand(3, 4096).astype('float32'),
        #    self.answer_inp_var: np.random.rand(3, 18, 8001).astype('float32')})

        n = prob_shuffled.shape[0] * prob_shuffled.shape[1]
        n_preds = preds_shuffled.shape[0] * preds_shuffled.shape[1]

        prob_rhp = T.reshape(prob_shuffled, (n, prob_shuffled.shape[2]))
        preds_rhp = T.reshape(preds_shuffled,
                              (n_preds, preds_shuffled.shape[2]))

        prob_sm = nn_utils.softmax_(prob_rhp)
        preds_sm = nn_utils.softmax_(preds_rhp)
        self.prediction = prob_sm  # this one is for the training.

        #print 'prob_sm', prob_sm.shape.eval({prob: np.random.rand(19,8897,3).astype('float32')})
        #print 'lbl', loss_vec.shape.eval({prob: np.random.rand(19,8897,3).astype('float32')})
        # This one is for the beamsearch.
        self.pred = T.reshape(
            preds_sm, (preds_shuffled.shape[0], preds_shuffled.shape[1],
                       preds_shuffled.shape[2]))

        mask = T.reshape(self.answer_mask, (n, ))
        lbl = T.reshape(self.answer_var, (n, ))

        self.params = [
            self.W_inp_emb_in,
            self.b_inp_emb_in,
            self.W_q_emb_in,
            self.b_q_emb_in,
            #self.W_glb_att_1, self.W_glb_att_2, self.b_glb_att_1, self.b_glb_att_2,
            self.W_qf_res_in,
            self.W_qf_res_hid,
            self.b_qf_res,
            self.W_qf_upd_in,
            self.W_qf_upd_hid,
            self.b_qf_upd,
            self.W_qf_hid_in,
            self.W_qf_hid_hid,
            self.b_qf_hid,
            self.W_mem_emb,
            self.W_inp_emb,
            self.b_mem_emb,
            self.b_inp_emb,
            self.W_mem_res_in,
            self.W_mem_res_hid,
            self.b_mem_res,
            self.W_mem_upd_in,
            self.W_mem_upd_hid,
            self.b_mem_upd,
            self.W_mem_hid_in,
            self.W_mem_hid_hid,
            self.b_mem_hid,  #self.W_b
            #self.W_mem_emb, self.W_inp_emb,self.b_mem_emb, self.b_inp_emb,
            self.W_1,
            self.W_2,
            self.b_1,
            self.b_2,
            self.W_a,
            self.W_ans_res_in,
            self.W_ans_res_hid,
            self.b_ans_res,
            self.W_ans_upd_in,
            self.W_ans_upd_hid,
            self.b_ans_upd,
            self.W_ans_hid_in,
            self.W_ans_hid_hid,
            self.b_ans_hid,
            self.W_ans_map,
            self.b_ans_map,
        ]
        self.params += self.W_mem_update
        self.params += self.b_mem_update

        print "==> building loss layer and computing updates"
        reward_prob = prob_sm[T.arange(n), lbl]
        reward_prob = T.reshape(
            reward_prob, (prob_shuffled.shape[0], prob_shuffled.shape[1]))
        #reward_prob = printing.Print('mean_r')(reward_prob)

        loss_vec = T.nnet.categorical_crossentropy(prob_sm, lbl)
        #loss_vec = T.nnet.categorical_crossentropy(prob_sm, T.flatten(self.answer_var))
        #print 'loss_vec', loss_vec.shape.eval({prob_sm: np.random.rand(39,8900).astype('float32'),
        #    lbl: np.random.rand(39,).astype('int32')})

        self.loss_ce = (mask * loss_vec).sum() / mask.sum()
        print 'loss_ce', self.loss_ce.eval({
            prob_sm:
            np.random.rand(39, 8900).astype('float32'),
            lbl:
            np.random.rand(39, ).astype('int32'),
            mask:
            np.random.rand(39, ).astype('float32')
        })

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0

        self.loss = self.loss_ce + self.loss_l2
        grads = T.grad(self.loss, wrt=self.params, disconnected_inputs='raise')

        updates = lasagne.updates.adadelta(grads,
                                           self.params,
                                           learning_rate=self.learning_rate)

        if self.mode == 'train':
            logging.info("compiling train_fn")
            self.train_fn = theano.function(
                inputs=[
                    self.input_var, self.q_var, self.answer_var,
                    self.answer_mask, self.answer_inp_var, self.answer_idx
                ],
                outputs=[self.prediction, self.loss],
                updates=updates)

        logging.info("compiling test_fn")
        self.test_fn = theano.function(inputs=[
            self.input_var, self.q_var, self.answer_var, self.answer_mask,
            self.answer_inp_var, self.answer_idx
        ],
                                       outputs=[self.prediction, self.loss])

        logging.info("compiling pred_fn")
        self.pred_fn = theano.function(inputs=[
            self.input_var, self.q_var, self.answer_inp_var, self.answer_idx
        ],
                                       outputs=[self.pred])
    def __init__(self, data_dir, word2vec, word_vector_size, dim, cnn_dim, story_len,
                patches,cnn_dim_fc,truncate_gradient, learning_rate,
                mode, answer_module, memory_hops, batch_size, l2,
                normalize_attention, batch_norm, dropout, **kwargs):
        
        print "==> not used params in DMN class:", kwargs.keys()

        self.data_dir = data_dir
        self.truncate_gradient = truncate_gradient
        self.learning_rate = learning_rate

        self.trng = RandomStreams(1234)
        
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.cnn_dim = cnn_dim
        self.cnn_dim_fc = cnn_dim_fc
        self.story_len = story_len
        self.patches = patches
        self.mode = mode
        self.answer_module = answer_module
        self.memory_hops = memory_hops
        self.batch_size = batch_size
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.batch_norm = batch_norm
        self.dropout = dropout

        self.vocab, self.ivocab = self._load_vocab(self.data_dir)

        self.train_story = None
        self.test_story = None
        self.train_dict_story, self.train_lmdb_env_fc, self.train_lmdb_env_conv = self._process_input_sind_lmdb(self.data_dir, 'train')
        self.test_dict_story, self.test_lmdb_env_fc, self.test_lmdb_env_conv = self._process_input_sind_lmdb(self.data_dir, 'val')

        self.train_story = self.train_dict_story.keys()
        self.test_story = self.test_dict_story.keys()
        self.vocab_size = len(self.vocab)
        self.alpha_entropy_c = 0.02 # for hard attention.
        
        # This is the local patch of each image.
        self.input_var = T.tensor4('input_var') # (batch_size, seq_len, patches, cnn_dim)
        self.q_var = T.tensor3('q_var') # Now, it's a batch * story_len * image_sieze.
        self.answer_var = T.ivector('answer_var') # answer of example in minibatch
        self.answer_mask = T.matrix('answer_mask')
        self.answer_idx = T.imatrix('answer_idx') # batch x seq
        self.answer_inp_var = T.tensor3('answer_inp_var') # answer of example in minibatch
        
        print "==> building input module"
        # It's very simple now, the input module just need to map from cnn_dim to dim.
        logging.info('self.cnn_dim = %d', self.cnn_dim)
        logging.info('self.cnn_dim_fc = %d', self.cnn_dim_fc)
        logging.info('self.dim = %d', self.dim)
        self.W_q_emb_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.cnn_dim_fc))
        self.b_q_emb_in = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        logging.info('Building the glob attention model')
        self.W_glb_att_1 = nn_utils.normal_param(std = 0.1, shape = (self.dim, 2 * self.dim))
        self.W_glb_att_2 = nn_utils.normal_param(std = 0.1, shape = (1, self.dim))
        self.b_glb_att_1 = nn_utils.constant_param(value = 0.0, shape = (self.dim,))
        self.b_glb_att_2 = nn_utils.constant_param(value = 0.0, shape = (1,))

        q_var_shuffled = self.q_var.dimshuffle(1,2,0) # seq x cnn x batch.

        def _dot(x, W, b):
            return T.tanh( T.dot(W, x) + b.dimshuffle(0, 'x'))

        q_var_shuffled_emb,_ = theano.scan(fn = _dot, sequences= q_var_shuffled, non_sequences = [self.W_q_emb_in, self.b_q_emb_in])
        print 'q_var_shuffled_emb', q_var_shuffled_emb.shape.eval({self.q_var:np.random.rand(2,5,4096).astype('float32')})
        q_var_emb =  q_var_shuffled_emb.dimshuffle(2,0,1) # batch x seq x emb_size
        q_var_emb_ext = q_var_emb.dimshuffle(0,'x',1,2)
        q_var_emb_ext = T.repeat(q_var_emb_ext, q_var_emb.shape[1],1) # batch x seq x seq x emb_size
        q_var_emb_rhp = T.reshape( q_var_emb, (q_var_emb.shape[0] * q_var_emb.shape[1], q_var_emb.shape[2]))
        q_var_emb_ext_rhp = T.reshape(q_var_emb_ext, (q_var_emb_ext.shape[0] * q_var_emb_ext.shape[1],q_var_emb_ext.shape[2], q_var_emb_ext.shape[3]))
        q_var_emb_ext_rhp = q_var_emb_ext_rhp.dimshuffle(0,2,1)
        q_idx = T.arange(self.story_len).dimshuffle('x',0)
        q_idx = T.repeat(q_idx,self.batch_size, axis = 0)
        q_idx = T.reshape(q_idx, (q_idx.shape[0]* q_idx.shape[1],))
        print q_idx.eval()
        print 'q_var_emb_rhp.shape', q_var_emb_rhp.shape.eval({self.q_var:np.random.rand(3,5,4096).astype('float32')})
        print 'q_var_emb_ext_rhp.shape', q_var_emb_ext_rhp.shape.eval({self.q_var:np.random.rand(3,5,4096).astype('float32')})

        #att_alpha,_ = theano.scan(fn = self.new_attention_step_glob, sequences = [q_var_emb_rhp, q_var_emb_ext_rhp, q_idx] )
        alpha,_ = theano.scan(fn = self.new_attention_step_glob, sequences = [q_var_emb_rhp, q_var_emb_ext_rhp, q_idx] )

        att_alpha = alpha[1]
        att_alpha_a = alpha[0]
        #print results.shape.eval({self.input_var: np.random.rand(3,4,4096).astype('float32'),
        print att_alpha.shape.eval({self.q_var:np.random.rand(3,5,4096).astype('float32')})

        # att_alpha: (batch x seq) x seq)
        
        self.W_inp_emb_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.cnn_dim))
        self.b_inp_emb_in = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        inp_rhp = T.reshape(self.input_var, (self.batch_size* self.story_len* self.patches, self.cnn_dim))
        inp_rhp_dimshuffled = inp_rhp.dimshuffle(1,0)
        inp_rhp_emb = T.dot(self.W_inp_emb_in, inp_rhp_dimshuffled) + self.b_inp_emb_in.dimshuffle(0,'x')
        inp_rhp_emb_dimshuffled = inp_rhp_emb.dimshuffle(1,0)
        inp_emb_raw = T.reshape(inp_rhp_emb_dimshuffled, (self.batch_size, self.story_len, self.patches, self.cnn_dim))
        inp_emb = T.tanh(inp_emb_raw) # Just follow the paper DMN for visual and textual QA.
        #inp_emb = inp_emb.dimshuffle(0,'x', 1,2)
        #inp_emb = T.repeat(inp_emb, self.story_len, 1)

        #print inp_emb.shape.eval({self.input_var:np.random.rand(3,5,196, 4096).astype('float32')})

        att_alpha_sample = self.trng.multinomial(pvals = att_alpha, dtype=theano.config.floatX)
        att_mask = att_alpha_sample.argmax(1)
        print 'att_mask.shape', att_mask.shape.eval({self.q_var:np.random.rand(2,5,4096).astype('float32')})
        print 'att_mask', att_mask.eval({self.q_var:np.random.rand(2,5,4096).astype('float32')})

        # No time to fix the hard attention, now we use the soft attention.
        idx_t = T.repeat(T.arange(self.input_var.shape[0]), self.input_var.shape[1])
        print 'idx_t', idx_t.eval({self.input_var:np.random.rand(2,5,196,512).astype('float32')})
        att_input =  inp_emb[idx_t, att_mask,:,:] # (batch x seq) x batches x emb_size
        att_input = T.reshape(att_input, (self.batch_size, self.story_len, self.patches, self.dim))
        print 'att_input', att_input.shape.eval({self.input_var:np.random.rand(2,5,196,512).astype('float32'),self.q_var:np.random.rand(2,5,4096).astype('float32')})
        
        # Now, it's the same size with the input_var, but we have only one image for each one of input.
        # Now, we can use the rnn on these local imgs to learn the 
        # Now, we use a bi-directional GRU to produce the input.
        # Forward GRU.


        self.inp_c = T.reshape(att_input, (att_input.shape[0] * att_input.shape[1], att_input.shape[2], att_input.shape[3]))
        self.inp_c = self.inp_c.dimshuffle(1,2,0)


        #print 'inp_c', self.inp_c.shape.eval({att_input:np.random.rand(2,5,196,512).astype('float32')})
        print "==> building question module"
        self.W_qf_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_qf_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_qf_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_qf_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_qf_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_qf_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_qf_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_qf_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_qf_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        inp_dummy = theano.shared(np.zeros((self.dim, self.batch_size), dtype = floatX))
        #print 'q_var_shuffled_emb', q_var_shuffled_emb.shape.eval({self.q_var:np.random.rand(2,5,4096).astype('float32')})
        q_glb,_ = theano.scan(fn = self.q_gru_step_forward, 
                                    sequences = q_var_shuffled_emb,
                                    outputs_info = [T.zeros_like(inp_dummy)])
        q_glb_shuffled = q_glb.dimshuffle(2,0,1) # batch_size * seq_len * dim
        #print 'q_glb_shuffled', q_glb_shuffled.shape.eval({self.q_var:np.random.rand(2,5,4096).astype('float32')})
        q_glb_last = q_glb_shuffled[:,-1,:] # batch_size * dim
        #print 'q_glb_last', q_glb_last.shape.eval({self.q_var:np.random.rand(2,5,4096).astype('float32')})

        q_net = layers.InputLayer(shape=(self.batch_size*self.story_len, self.dim), input_var=q_var_emb_rhp)
        if self.batch_norm:
            q_net = layers.BatchNormLayer(incoming=q_net)
        if self.dropout > 0 and self.mode == 'train':
            q_net = layers.DropoutLayer(q_net, p=self.dropout)
        self.q_q = layers.get_output(q_net).dimshuffle(1,0)

        print "==> creating parameters for memory module"
        self.W_mem_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_mem_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_mem_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_mem_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_mem_update1 = nn_utils.normal_param(std=0.1, shape=(self.dim , self.dim* 2))
        self.b_mem_upd1 = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        self.W_mem_update2 = nn_utils.normal_param(std=0.1, shape=(self.dim,self.dim*2))
        self.b_mem_upd2 = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        self.W_mem_update3 = nn_utils.normal_param(std=0.1, shape=(self.dim , self.dim*2))
        self.b_mem_upd3 = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        self.W_mem_update = [self.W_mem_update1,self.W_mem_update2,self.W_mem_update3]
        self.b_mem_update = [self.b_mem_upd1,self.b_mem_upd2, self.b_mem_upd3]
        
        self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 0))
        self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
        self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        self.b_2 = nn_utils.constant_param(value=0.0, shape=(1,))
        print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
        memory = [self.q_q.copy()]
        for iter in range(1, self.memory_hops + 1):
            #m = printing.Print('mem')(memory[iter-1])
            current_episode = self.new_episode(memory[iter - 1])
            # Replace GRU with ReLU activation + MLP.
            c = T.concatenate([memory[iter - 1], current_episode], axis = 0)
            cur_mem = T.dot(self.W_mem_update[iter-1], c) + self.b_mem_update[iter-1].dimshuffle(0,'x')
            memory.append(T.nnet.relu(cur_mem))
        
        last_mem_raw = memory[-1].dimshuffle((1, 0))
        
        net = layers.InputLayer(shape=(self.batch_size * self.story_len, self.dim), input_var=last_mem_raw)

        if self.batch_norm:
            net = layers.BatchNormLayer(incoming=net)
        if self.dropout > 0 and self.mode == 'train':
            net = layers.DropoutLayer(net, p=self.dropout)
        last_mem = layers.get_output(net).dimshuffle((1, 0))

        print "==> building answer module"

        answer_inp_var_shuffled = self.answer_inp_var.dimshuffle(1,2,0)
        # Sounds good. Now, we need to map last_mem to a new space. 
        self.W_mem_emb = nn_utils.normal_param(std = 0.1, shape = (self.dim, self.dim * 2))
        self.b_mem_emb = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        self.W_inp_emb = nn_utils.normal_param(std = 0.1, shape = (self.dim, self.vocab_size + 1))
        self.b_inp_emb = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        def _dot2(x, W, b):
            #return  T.tanh(T.dot(W, x) + b.dimshuffle(0,'x'))
            return  T.dot(W, x) + b.dimshuffle(0,'x')

        answer_inp_var_shuffled_emb,_ = theano.scan(fn = _dot2, sequences = answer_inp_var_shuffled,
                non_sequences = [self.W_inp_emb,self.b_inp_emb] ) # seq x dim x batch
        
        #print 'answer_inp_var_shuffled_emb', answer_inp_var_shuffled_emb.shape.eval({self.answer_inp_var:np.random.rand(2,5,8900).astype('float32')})
        # dim x batch_size * 5
        q_glb_dim = q_glb_last.dimshuffle(0,'x', 1) # batch_size * 1 * dim
        #print 'q_glb_dim', q_glb_dim.shape.eval({self.q_var:np.random.rand(2,5,4096).astype('float32')})
        
        q_glb_repmat = T.repeat(q_glb_dim, self.story_len, 1) # batch_size * len * dim
        #print 'q_glb_repmat', q_glb_repmat.shape.eval({self.q_var:np.random.rand(2,5,4096).astype('float32')})
        q_glb_rhp = T.reshape(q_glb_repmat, (q_glb_repmat.shape[0] * q_glb_repmat.shape[1], q_glb_repmat.shape[2]))
        #print 'q_glb_rhp', q_glb_rhp.shape.eval({q_glb_last:np.random.rand(2,512).astype('float32')})

        init_ans = T.concatenate([self.q_q, last_mem], axis = 0)
        #print 'init_ans', init_ans.shape.eval({self.q_var:np.random.rand(2,5,4096).astype('float32'), self.input_var:np.random.rand(2,5,196, 512).astype('float32')})

        mem_ans = T.dot(self.W_mem_emb, init_ans) + self.b_mem_emb.dimshuffle(0,'x') # dim x batchsize
        mem_ans_dim = mem_ans.dimshuffle('x',0,1)
        answer_inp = T.concatenate([mem_ans_dim, answer_inp_var_shuffled_emb], axis = 0)

        q_glb_rhp = q_glb_rhp.dimshuffle(1,0)
        q_glb_rhp = q_glb_rhp.dimshuffle('x', 0, 1)
        q_glb_step = T.repeat(q_glb_rhp, answer_inp.shape[0], 0)

        #mem_ans = T.tanh(T.dot(self.W_mem_emb, init_ans) + self.b_mem_emb.dimshuffle(0,'x')) # dim x batchsize.
        # seq + 1 x dim x batch 
        answer_inp = T.concatenate([answer_inp, q_glb_step], axis = 1)
        dummy = theano.shared(np.zeros((self.dim, self.batch_size * self.story_len), dtype=floatX))

        self.W_a = nn_utils.normal_param(std=0.1, shape=(self.vocab_size + 1, self.dim))
        
        self.W_ans_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim * 2))
        self.W_ans_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_ans_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim * 2))
        self.W_ans_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_ans_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim * 2))
        self.W_ans_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        results, _ = theano.scan(fn = self.answer_gru_step,
            sequences = answer_inp,
            outputs_info = [ dummy ])

        #results = None
        #r = None
        #for i in range(self.story_len):
        #    answer_inp_i = answer_inp[i,:]
        #    
        #    if i == 0:
        #        # results: seq + 1 x dim x batch_size
        #        r, _ = theano.scan(fn = self.answer_gru_step,
        #            sequences = answer_inp_i,
        #            truncate_gradient = self.truncate_gradient,
        #            outputs_info = [ dummy ])
        #        #print 'r', r.shape.eval({answer_inp_i:np.random.rand(23,512,2).astype('float32')})
        #        results = r.dimshuffle('x', 0, 1,2)
        #    else:
        #        prev_h = r[self.answer_idx[:,i],:,T.arange(self.batch_size)]
        #        #print 'prev_h', prev_h.shape.eval({answer_inp_i:np.random.rand(23,512,2).astype('float32'), self.answer_idx: np.asarray([[1,1,1,1,1],[2,2,2,2,2]]).astype('int32')},on_unused_input='warn' )
        #        #print 'prev_h', prev_h.shape.eval({r:np.random.rand(23,512,2).astype('float32'), self.answer_idx: np.asarray([[1,1,1,1,1],[2,2,2,2,2]]).astype('int32')})


        #        r,_ = theano.scan(fn = self.answer_gru_step,
        #                sequences = answer_inp_i,
        #                truncate_gradient = self.truncate_gradient,
        #                outputs_info = [ prev_h.dimshuffle(1,0) ])
        #        results = T.concatenate([results, r.dimshuffle('x', 0, 1, 2)])
        ## results: story_len x seq+1 x dim x batch_size
        #results = results.dimshuffle(3,0,1,2)
        #results = T.reshape(results, (self.batch_size * self.story_len, results.shape[2], results.shape[3]))
        #results = results.dimshuffle(1,2,0) # seq_len x dim x (batch x seq)

        # Assume there is a start token 
        #print 'results', results.shape.eval({self.input_var: np.random.rand(2,5,196,512).astype('float32'),
        #    self.q_var: np.random.rand(2,5, 4096).astype('float32'), 
        #    self.answer_idx: np.asarray([[1,1,1,1,1],[2,2,2,2,2]]).astype('int32'),
        #    self.answer_inp_var: np.random.rand(5, 18, 8001).astype('float32')})

        #results = results[1:-1,:,:] # get rid of the last token as well as the first one (image)
        #print results.shape.eval({self.input_var: np.random.rand(3,4,4096).astype('float32'),
        #    self.q_var: np.random.rand(3, 4096).astype('float32'), 
        #    self.answer_inp_var: np.random.rand(3, 18, 8001).astype('float32')}, on_unused_input='ignore')
            
        # Now, we need to transform it to the probabilities.

        prob,_ = theano.scan(fn = lambda x, w: T.dot(w, x), sequences = results, non_sequences = self.W_a )
        #print 'prob', prob.shape.eval({self.input_var: np.random.rand(2,5,196,512).astype('float32'),
        #    self.q_var: np.random.rand(2,5, 4096).astype('float32'), 
        #    self.answer_idx: np.asarray([[1,1,1,1,1],[2,2,2,2,2]]).astype('int32'),
        #    self.answer_inp_var: np.random.rand(5, 18, 8001).astype('float32')})


        preds = prob[1:,:,:]
        prob = prob[1:-1,:,:]

        prob_shuffled = prob.dimshuffle(2,0,1) # b * len * vocab
        preds_shuffled = preds.dimshuffle(2,0,1)


        logging.info("prob shape.")
        #print prob.shape.eval({self.input_var: np.random.rand(3,4,4096).astype('float32'),
        #    self.q_var: np.random.rand(3, 4096).astype('float32'), 
        #    self.answer_inp_var: np.random.rand(3, 18, 8001).astype('float32')})

        n = prob_shuffled.shape[0] * prob_shuffled.shape[1]
        n_preds = preds_shuffled.shape[0] * preds_shuffled.shape[1]

        prob_rhp = T.reshape(prob_shuffled, (n, prob_shuffled.shape[2]))
        preds_rhp = T.reshape(preds_shuffled, (n_preds, preds_shuffled.shape[2]))

        prob_sm = nn_utils.softmax_(prob_rhp)
        preds_sm = nn_utils.softmax_(preds_rhp)
        self.prediction = prob_sm # this one is for the training.

        #print 'prob_sm', prob_sm.shape.eval({prob: np.random.rand(19,8897,3).astype('float32')})
        #print 'lbl', loss_vec.shape.eval({prob: np.random.rand(19,8897,3).astype('float32')})
        # This one is for the beamsearch.
        self.pred = T.reshape(preds_sm, (preds_shuffled.shape[0], preds_shuffled.shape[1], preds_shuffled.shape[2]))

        mask =  T.reshape(self.answer_mask, (n,))
        lbl = T.reshape(self.answer_var, (n,))

        self.params = [self.W_inp_emb_in, self.b_inp_emb_in, 
                self.W_q_emb_in, self.b_q_emb_in,
                self.W_glb_att_1, self.W_glb_att_2, self.b_glb_att_1, self.b_glb_att_2,
                self.W_qf_res_in, self.W_qf_res_hid, self.b_qf_res,
                self.W_qf_upd_in, self.W_qf_upd_hid, self.b_qf_upd,
                self.W_qf_hid_in, self.W_qf_hid_hid, self.b_qf_hid,
                self.W_mem_emb, self.W_inp_emb,self.b_mem_emb, self.b_inp_emb,
                self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res, 
                self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
                self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, #self.W_b
                #self.W_mem_emb, self.W_inp_emb,self.b_mem_emb, self.b_inp_emb,
                self.W_1, self.W_2, self.b_1, self.b_2, self.W_a,
                self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res, 
                self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid,
                ]
        self.params += self.W_mem_update
        self.params += self.b_mem_update
                              
                              
        print "==> building loss layer and computing updates"
        reward_prob = prob_sm[T.arange(n), lbl]
        reward_prob = T.reshape(reward_prob, (prob_shuffled.shape[0], prob_shuffled.shape[1]))
        #reward_prob = printing.Print('mean_r')(reward_prob)

        loss_vec = T.nnet.categorical_crossentropy(prob_sm, lbl)
        #loss_vec = T.nnet.categorical_crossentropy(prob_sm, T.flatten(self.answer_var))
        #print 'loss_vec', loss_vec.shape.eval({prob_sm: np.random.rand(39,8900).astype('float32'),
        #    lbl: np.random.rand(39,).astype('int32')})

        self.loss_ce = (mask * loss_vec ).sum() / mask.sum() 
        print 'loss_ce', self.loss_ce.eval({prob_sm: np.random.rand(39,8900).astype('float32'),
            lbl: np.random.rand(39,).astype('int32'), mask: np.random.rand(39,).astype('float32')})

        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0
        
        self.loss = self.loss_ce + self.loss_l2
        self.baseline_time = theano.shared(np.float32(0.), name='baseline_time')
        alpha_entropy_c = theano.shared(np.float32(self.alpha_entropy_c), name='alpha_entropy_c')
        #mean_r = ( mask * reward_prob).sum() / mask.sum() # or just fixed it as 1.
        #mean_r = 1
        mean_r = (self.answer_mask * reward_prob).sum(1) / self.answer_mask.sum(1) # or just fixed it as 1.
        mean_r = mean_r[0,None]
        grads = T.grad(self.loss, wrt=self.params,
                     disconnected_inputs='raise',
                     known_grads={att_alpha_a:(mean_r - self.baseline_time)*
                     (att_alpha_sample/(att_alpha_a + 1e-10)) + alpha_entropy_c*(T.log(att_alpha_a + 1e-10) + 1)})

            
        updates = lasagne.updates.adadelta(grads, self.params, learning_rate = self.learning_rate)
        updates[self.baseline_time] =  self.baseline_time * 0.9 + 0.1 * mean_r.mean()
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.001)
        
        if self.mode == 'train':
            logging.info("compiling train_fn")
            self.train_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.answer_mask, self.answer_inp_var], 
                                            outputs=[self.prediction, self.loss],
                                            updates=updates)
        
        logging.info("compiling test_fn")
        self.test_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.answer_mask, self.answer_inp_var],
                                       outputs=[self.prediction, self.loss])
        
        logging.info("compiling pred_fn")
        self.pred_fn= theano.function(inputs=[self.input_var, self.q_var, self.answer_inp_var],
                                       outputs=[self.pred])
Exemple #6
0
    def __init__(self, data_dir, word2vec, word_vector_size, dim,
                mode, answer_module, memory_hops, batch_size, l2,
                normalize_attention, batch_norm, dropout, **kwargs):
        
        print "==> not used params in DMN class:", kwargs.keys()

        self.data_dir = data_dir
        
        self.word2vec = word2vec
        self.word_vector_size = word_vector_size
        self.dim = dim
        self.mode = mode
        self.answer_module = answer_module
        self.memory_hops = memory_hops
        self.batch_size = batch_size
        self.l2 = l2
        self.normalize_attention = normalize_attention
        self.batch_norm = batch_norm
        self.dropout = dropout

        self.vocab, self.ivocab = self._load_vocab(self.data_dir)

        self.train_story = None
        self.test_story = None
        self.train_dict_story, self.train_features, self.train_fns_dict, self.train_num_imgs = self._process_input_sind(self.data_dir, 'train')
        self.test_dict_story, self.test_features, self.test_fns_dict, self.test_num_imgs = self._process_input_sind(self.data_dir, 'val')

        self.train_story = self.train_dict_story.keys()
        self.test_story = self.test_dict_story.keys()
        self.vocab_size = len(self.vocab)
        
        self.q_var = T.matrix('q_var') # Now, it's a batch * image_sieze.
        self.answer_var = T.imatrix('answer_var') # answer of example in minibatch
        self.answer_mask = T.matrix('answer_mask')
        self.answer_inp_var = T.tensor3('answer_inp_var') # answer of example in minibatch
        
        print "==> building question module"
        # Now, share the parameter with the input module.
        q_var_shuffled = self.q_var.dimshuffle(1,0)
        self.W_inp_emb_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.cnn_dim))
        self.b_inp_emb_in = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        q_hist = T.dot(self.W_inp_emb_in, q_var_shuffled) + self.b_inp_emb_in.dimshuffle(0,'x')

        q_hist_shuffled = q_hist.dimshuffle(1,0)

        if self.batch_norm:
            logging.info("Using batch normalization.")
        q_net = layers.InputLayer(shape=(self.batch_size, self.dim), input_var=q_hist_shuffled)
        if self.batch_norm:
            q_net = layers.BatchNormLayer(incoming=q_net)
        if self.dropout > 0 and self.mode == 'train':
            q_net = layers.DropoutLayer(q_net, p=self.dropout)
        #last_mem = layers.get_output(q_net).dimshuffle((1, 0))
        self.q_q = layers.get_output(q_net).dimshuffle(1,0)

        print "==> building answer module"

        answer_inp_var_shuffled = self.answer_inp_var.dimshuffle(1,2,0)
        #self.W_mem_emb = nn_utils.normal_param(std = 0.1, shape = (self.dim, self.dim))
        self.W_inp_emb = nn_utils.normal_param(std = 0.1, shape = (self.dim, self.vocab_size + 1))

        def _dot2(x, W):
            return  T.dot(W, x)

        answer_inp_var_shuffled_emb,_ = theano.scan(fn = _dot2, sequences = answer_inp_var_shuffled,
                non_sequences = self.W_inp_emb ) # seq x dim x batch
        

        mem_ans = self.q_q
        mem_ans_dim = mem_ans.dimshuffle('x',0,1)

        answer_inp = T.concatenate([mem_ans_dim, answer_inp_var_shuffled_emb], axis = 0)
        
        dummy = theano.shared(np.zeros((self.dim, self.batch_size), dtype=floatX))

        self.W_a = nn_utils.normal_param(std=0.1, shape=(self.vocab_size + 1, self.dim))
        
        self.W_ans_res_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_ans_res_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_ans_res = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_ans_upd_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_ans_upd_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_ans_upd = nn_utils.constant_param(value=0.0, shape=(self.dim,))
        
        self.W_ans_hid_in = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.W_ans_hid_hid = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
        self.b_ans_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))

        logging.info('answer_inp size')

        #last_mem = printing.Print('prob_sm')(last_mem)
        results, _ = theano.scan(fn = self.answer_gru_step,
                sequences = answer_inp,
                outputs_info = [ dummy ])

        prob,_ = theano.scan(fn = lambda x, w: T.dot(w, x), sequences = results, non_sequences = self.W_a )
        preds = prob[1:,:,:]
        prob = prob[1:-1,:,:]

        prob_shuffled = prob.dimshuffle(2,0,1) # b * len * vocab
        preds_shuffled = preds.dimshuffle(2,0,1)


        logging.info("prob shape.")
        #print prob.shape.eval({self.input_var: np.random.rand(10,4,4096).astype('float32'),
        #    self.q_var: np.random.rand(10, 4096).astype('float32'), 
        #    self.answer_inp_var: np.random.rand(10, 18, 8001).astype('float32')})

        n = prob_shuffled.shape[0] * prob_shuffled.shape[1]
        n_preds = preds_shuffled.shape[0] * preds_shuffled.shape[1]

        prob_rhp = T.reshape(prob_shuffled, (n, prob_shuffled.shape[2]))
        preds_rhp = T.reshape(preds_shuffled, (n_preds, preds_shuffled.shape[2]))

        prob_sm = nn_utils.softmax_(prob_rhp)
        preds_sm = nn_utils.softmax_(preds_rhp)
        self.prediction = prob_sm # this one is for the training.

        # This one is for the beamsearch.
        self.pred = T.reshape(preds_sm, (preds_shuffled.shape[0], preds_shuffled.shape[1], preds_shuffled.shape[2]))

        mask =  T.reshape(self.answer_mask, (n,))
        lbl = T.reshape(self.answer_var, (n,))

        self.params = [self.W_a,self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res, 
                              self.W_ans_upd_in, self.W_ans_upd_hid, self.b_ans_upd,
                              self.W_ans_hid_in, self.W_ans_hid_hid, self.b_ans_hid,
                              self.W_inp_emb_in, self.b_inp_emb_in,
                              self.W_inp_emb]
                              
        print "==> building loss layer and computing updates"
        loss_vec = T.nnet.categorical_crossentropy(prob_sm, lbl)
        self.loss_ce = (mask * loss_vec ).sum() / mask.sum() 

        #self.loss_ce = T.nnet.categorical_crossentropy(results_rhp, lbl)
            
        if self.l2 > 0:
            self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
        else:
            self.loss_l2 = 0
        
        self.loss = self.loss_ce + self.loss_l2

        grad = T.grad(self.loss, self.params)
        #scaled_grad = lasagne.updates.norm_constraint(grad, max_norm = 1e4)
        updates = lasagne.updates.adadelta(self.loss, self.params, learning_rate = 0.01)
        #updates = lasagne.updates.momentum(self.loss, self.params, learning_rate=0.001)
        
        if self.mode == 'train':
            print "==> compiling train_fn"
            self.train_fn = theano.function(inputs=[self.q_var, self.answer_var, self.answer_mask, self.answer_inp_var], 
                                            outputs=[self.prediction, self.loss],
                                            updates=updates)
        
        print "==> compiling test_fn"
        self.test_fn = theano.function(inputs=[self.q_var, self.answer_var, self.answer_mask, self.answer_inp_var],
                                       outputs=[self.prediction, self.loss])
        
    
        print "==> compiling pred_fn"
        self.pred_fn= theano.function(inputs=[self.q_var, self.answer_inp_var],
                                       outputs=[self.pred])