Example #1
    def __init__(self,
                 vid_encoder,
                 qns_encoder,
                 max_len_v,
                 max_len_q,
                 device,
                 input_drop_p=0.2):
        """
        Heterogeneous memory enhanced multimodal attention model for video question answering (CVPR19)
        :param vid_encoder:
        :param qns_encoder:
        :param ans_decoder:
        :param device:
        """
        super(HME, self).__init__()
        self.vid_encoder = vid_encoder
        self.qns_encoder = qns_encoder

        dim = qns_encoder.dim_hidden

        self.temp_att_a = TempAttention(dim * 2, dim * 2, hidden_dim=256)
        self.temp_att_m = TempAttention(dim * 2, dim * 2, hidden_dim=256)
        self.mrm_vid = MemoryRamTwoStreamModule(dim, dim, max_len_v, device)
        self.mrm_txt = MemoryRamModule(dim, dim, max_len_q, device)

        self.mm_module_v1 = MMModule(dim, input_drop_p, device)

        self.linear_vid = nn.Linear(dim * 2, dim)
        self.linear_qns = nn.Linear(dim * 2, dim)
        self.linear_mem = nn.Linear(dim * 2, dim)
        self.vq2word_hme = nn.Linear(dim * 3, 1)
        self.device = device
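
A minimal construction sketch (added for orientation; DummyEncoder is a hypothetical stand-in that only exposes the dim_hidden attribute read above, the lengths are arbitrary, and TempAttention, MemoryRamTwoStreamModule, MemoryRamModule and MMModule must be importable from the same codebase):

import torch
import torch.nn as nn

class DummyEncoder(nn.Module):
    """Hypothetical placeholder providing the dim_hidden attribute HME.__init__ reads."""
    def __init__(self, dim_hidden=256):
        super(DummyEncoder, self).__init__()
        self.dim_hidden = dim_hidden

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HME(vid_encoder=DummyEncoder(256),
            qns_encoder=DummyEncoder(256),
            max_len_v=80,      # capacity of the visual memory
            max_len_q=35,      # capacity of the textual memory
            device=device,
            input_drop_p=0.2)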
Example #2
    def __init__(self, task, feat_channel, feat_dim, text_embed_size, hidden_size, vocab_size, num_layers, word_matrix,
                    answer_vocab_size=None, max_len=20, dropout=0.2, mm_version=1, useSpatial=False, useNaive=False, iter_num=3):
        """Set the hyper-parameters and build the layers."""
        super(AttentionTwoStream, self).__init__()
        
        self.task = task
        
        
        # text input size
        self.text_embed_size = text_embed_size # should be 300
        
        # video input size
        self.feat_channel = feat_channel
        self.feat_dim = feat_dim # should be 7
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.useNaive = useNaive
        self.useSpatial = useSpatial
        self.mm_version = mm_version
        
        self.TpAtt_a = TemporalAttentionModule(hidden_size*2, hidden_size)
        self.TpAtt_m = TemporalAttentionModule(hidden_size*2, hidden_size)

        if useSpatial:
            self.SpAtt = SpatialAttentionModule(feat_channel, feat_dim, hidden_size)
        else:
            self.video_encoder = nn.Linear(feat_channel, hidden_size) 

        self.drop_keep_prob_final_att_vec = nn.Dropout(dropout)
        self.embed = nn.Embedding(vocab_size, text_embed_size)
                    
        self.lstm_text_1 = nn.LSTMCell(text_embed_size, hidden_size)
        self.lstm_text_2 = nn.LSTMCell(hidden_size, hidden_size)

        self.lstm_video_1a = nn.LSTMCell(2048, hidden_size)
        self.lstm_video_2a = nn.LSTMCell(hidden_size, hidden_size)
        
        self.lstm_video_1m = nn.LSTMCell(4096, hidden_size)
        self.lstm_video_2m = nn.LSTMCell(hidden_size, hidden_size)
        
        self.iter_num = iter_num
        
        if mm_version==1:
            self.lstm_mm_1 = nn.LSTMCell(hidden_size, hidden_size)
            self.lstm_mm_2 = nn.LSTMCell(hidden_size, hidden_size)
            self.linear_decoder_mem = nn.Linear(hidden_size * 2, hidden_size) 
            self.hidden_encoder_1 = nn.Linear(hidden_size * 2, hidden_size) 
            self.hidden_encoder_2 = nn.Linear(hidden_size * 2, hidden_size) 
        
        else:
            self.gru_mm = nn.GRUCell(hidden_size, hidden_size)
            self.linear_decoder_mem = nn.Linear(hidden_size, hidden_size) 
        

        self.mm_att = MultiModalAttentionModule(hidden_size)
        
        self.linear_decoder_att_a = nn.Linear(hidden_size * 2, hidden_size) 
        self.linear_decoder_att_m = nn.Linear(hidden_size * 2, hidden_size) 

        
        if answer_vocab_size is not None:
            self.linear_decoder_count_2 = nn.Linear(hidden_size * 2 + hidden_size, answer_vocab_size)
        else:
            self.linear_decoder_count_2 = nn.Linear(hidden_size * 2 + hidden_size, 1)    # Count is regression problem
                
        self.max_len = max_len


        self.mrm_vid = MemoryRamTwoStreamModule(hidden_size, hidden_size, max_len)
        self.mrm_txt = MemoryRamModule(hidden_size, hidden_size, max_len)

        
        self.init_weights(word_matrix)
Example #3
import torch
import torch.nn as nn

# TemporalAttentionModule, SpatialAttentionModule, MultiModalAttentionModule,
# MemoryRamModule and MemoryRamTwoStreamModule are defined elsewhere in the repository.


class AttentionTwoStream(nn.Module):

    def __init__(self, task, feat_channel, feat_dim, text_embed_size, hidden_size, vocab_size, num_layers, word_matrix,
                    answer_vocab_size=None, max_len=20, dropout=0.2, mm_version=1, useSpatial=False, useNaive=False, iter_num=3):
        """Set the hyper-parameters and build the layers."""
        super(AttentionTwoStream, self).__init__()
        
        self.task = task
        
        
        # text input size
        self.text_embed_size = text_embed_size # should be 300
        
        # video input size
        self.feat_channel = feat_channel
        self.feat_dim = feat_dim # should be 7
        
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.useNaive = useNaive
        self.useSpatial = useSpatial
        self.mm_version = mm_version
        
        self.TpAtt_a = TemporalAttentionModule(hidden_size*2, hidden_size)
        self.TpAtt_m = TemporalAttentionModule(hidden_size*2, hidden_size)

        if useSpatial:
            self.SpAtt = SpatialAttentionModule(feat_channel, feat_dim, hidden_size)
        else:
            self.video_encoder = nn.Linear(feat_channel, hidden_size) 

        self.drop_keep_prob_final_att_vec = nn.Dropout(dropout)
        self.embed = nn.Embedding(vocab_size, text_embed_size)
                    
        self.lstm_text_1 = nn.LSTMCell(text_embed_size, hidden_size)
        self.lstm_text_2 = nn.LSTMCell(hidden_size, hidden_size)

        self.lstm_video_1a = nn.LSTMCell(2048, hidden_size)
        self.lstm_video_2a = nn.LSTMCell(hidden_size, hidden_size)
        
        self.lstm_video_1m = nn.LSTMCell(4096, hidden_size)
        self.lstm_video_2m = nn.LSTMCell(hidden_size, hidden_size)
        
        self.iter_num = iter_num
        
        if mm_version==1:
            self.lstm_mm_1 = nn.LSTMCell(hidden_size, hidden_size)
            self.lstm_mm_2 = nn.LSTMCell(hidden_size, hidden_size)
            self.linear_decoder_mem = nn.Linear(hidden_size * 2, hidden_size) 
            self.hidden_encoder_1 = nn.Linear(hidden_size * 2, hidden_size) 
            self.hidden_encoder_2 = nn.Linear(hidden_size * 2, hidden_size) 
        
        else:
            self.gru_mm = nn.GRUCell(hidden_size, hidden_size)
            self.linear_decoder_mem = nn.Linear(hidden_size, hidden_size) 
        

        self.mm_att = MultiModalAttentionModule(hidden_size)
        
        self.linear_decoder_att_a = nn.Linear(hidden_size * 2, hidden_size) 
        self.linear_decoder_att_m = nn.Linear(hidden_size * 2, hidden_size) 

        
        if answer_vocab_size is not None:
            self.linear_decoder_count_2 = nn.Linear(hidden_size * 2 + hidden_size, answer_vocab_size)
        else:
            self.linear_decoder_count_2 = nn.Linear(hidden_size * 2 + hidden_size, 1)    # Count is regression problem
                
        self.max_len = max_len


        self.mrm_vid = MemoryRamTwoStreamModule(hidden_size, hidden_size, max_len)
        self.mrm_txt = MemoryRamModule(hidden_size, hidden_size, max_len)

        
        self.init_weights(word_matrix)


    def init_weights(self, word_matrix):
        """Initialize weights."""
        
        if word_matrix is None:
            self.embed.weight.data.uniform_(-0.1, 0.1)
        else:
            # init embed from glove
            self.embed.weight.data.copy_(torch.from_numpy(word_matrix))
        
        self.mrm_vid.init_weights()
        self.mrm_txt.init_weights()

    def init_hiddens(self):
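        # Hidden and cell states are allocated directly on the GPU, so the model
        # as written assumes CUDA is available (note the .cuda() calls below).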
        s_t = torch.zeros(1, self.hidden_size).cuda()
        s_t2 = torch.zeros(1, self.hidden_size).cuda()
        c_t = torch.zeros(1, self.hidden_size).cuda()
        c_t2 = torch.zeros(1, self.hidden_size).cuda()
        return s_t,s_t2,c_t,c_t2
    
    
    def mm_module_v1(self,svt_tmp,memory_ram_vid,memory_ram_txt,loop=3):
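        """
        Iterative multimodal fusion (version 1): a two-layer LSTM controller is
        unrolled `loop` times; at each step its state attends over the video and
        text memory banks and the attended read is fed back as the next input.
        Returns the concatenated controller states, of shape (1, 2 * hidden_size).
        """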

        sm_q1,sm_q2,cm_q1,cm_q2 = self.init_hiddens()
        mm_oo = self.drop_keep_prob_final_att_vec(torch.tanh(self.hidden_encoder_1(svt_tmp)))
        
        for _ in range(loop):
            
            sm_q1, cm_q1 = self.lstm_mm_1(mm_oo, (sm_q1, cm_q1))
            sm_q2, cm_q2 = self.lstm_mm_2(sm_q1, (sm_q2, cm_q2))
        
            mm_o1 = self.mm_att(sm_q2,memory_ram_vid,memory_ram_txt)
            mm_o2 = torch.cat((sm_q2,mm_o1),dim=1)
            mm_oo = self.drop_keep_prob_final_att_vec(torch.tanh(self.hidden_encoder_2(mm_o2)))
        
        smq = torch.cat( (sm_q1,sm_q2), dim=1)
        return smq
    
    
    def mm_module_v2(self,memory_ram_vid,memory_ram_txt,loop=5):
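        """
        Iterative multimodal fusion (version 2): a single GRU cell is unrolled
        `loop` times, re-attending over the video and text memories with the
        current hidden state. Returns the final hidden state of shape (1, hidden_size).
        """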
        
        h_t = torch.zeros(1, self.hidden_size).cuda()
                
        for _ in range(loop):
            mm_o = self.mm_att(h_t,memory_ram_vid,memory_ram_txt)
            h_t = self.gru_mm(mm_o, h_t)
        
        return h_t
    

    def forward(self, data_dict, question_type='Count'):
        if question_type=='Count':
            ret = self.forward_count(data_dict)
        elif question_type=='Action':
            ret = self.forward_action(data_dict)
        elif question_type=='Trans':
            ret = self.forward_trans(data_dict)
        else:
            assert question_type=='FrameQA'
            ret = self.forward_frameqa(data_dict)

        return ret

    def forward_count(self, data_dict):
        
        video_features, numImgs = data_dict['video_features'], data_dict['video_lengths']
        questions, question_lengths = data_dict['question_words'], data_dict['question_lengths']
        answers = data_dict['answers']           
                
        outputs = []
        predictions = []
        bsize = len(questions)
        batch_size = len(questions)  # batch size has to be 1


        cnt = 0
        features_questions = self.embed(questions)
        
        for j in range(batch_size):
        
            nImg = numImgs[j]
            nQuestionWords = question_lengths[j]

            ################################
            # slice the input image features
            ################################
            feature = video_features[j,video_features.size(1)-nImg:]
            #print('current video feature size', feature.size())            

            
            #############################             
            # run text encoder first time
            #############################
            s1_t1,s1_t2,c1_t1,c1_t2 = self.init_hiddens()
            
            for i in range(nQuestionWords):
                input_question = features_questions[j,i:i+1]
                s1_t1, c1_t1 = self.lstm_text_1(input_question, (s1_t1, c1_t1))
                s1_t2, c1_t2 = self.lstm_text_2(s1_t1, (s1_t2, c1_t2))
            
            # here s1_t1, s1_t2 is the last hidden
            s1_t = torch.cat( (s1_t1,s1_t2), dim=1)  # should be of size (1,1024)
            
            
            
            
            ###########################################             
            # run video encoder with spatial attention
            ###########################################
            sV_t1a,sV_t2a,cV_t1a,cV_t2a = s1_t1,s1_t2,c1_t1,c1_t2
            sV_t1m,sV_t2m,cV_t1m,cV_t2m = s1_t1,s1_t2,c1_t1,c1_t2
  
                        
            
            # record each time t, hidden states, for later temporal attention after text encoding
            hidden_array_1a = []
            hidden_array_2a = []
            hidden_array_1m = []
            hidden_array_2m = []
            
            
            for i in range(nImg):
                
                if self.useSpatial:
                    input_frame = feature[i:i+1]
                    feat_att = self.SpAtt(input_frame, s1_t)
                else:
                    feat_att_m = feature[i:i+1,0,0,:4096]
                    feat_att_a = feature[i:i+1,0,0,4096:]
                
                sV_t1m, cV_t1m = self.lstm_video_1m(feat_att_m, (sV_t1m, cV_t1m))
                sV_t2m, cV_t2m = self.lstm_video_2m(sV_t1m, (sV_t2m, cV_t2m))

                sV_t1a, cV_t1a = self.lstm_video_1a(feat_att_a, (sV_t1a, cV_t1a))
                sV_t2a, cV_t2a = self.lstm_video_2a(sV_t1a, (sV_t2a, cV_t2a))
                
                
                sV_t1a_vec = sV_t1a.view(sV_t1a.size(0),1,sV_t1a.size(1))
                sV_t2a_vec = sV_t2a.view(sV_t2a.size(0),1,sV_t2a.size(1))
                
                hidden_array_1a.append(sV_t1a_vec)
                hidden_array_2a.append(sV_t2a_vec)
                
                sV_t1m_vec = sV_t1m.view(sV_t1m.size(0),1,sV_t1m.size(1))
                sV_t2m_vec = sV_t2m.view(sV_t2m.size(0),1,sV_t2m.size(1))
                
                hidden_array_1m.append(sV_t1m_vec)
                hidden_array_2m.append(sV_t2m_vec)
                

            # assume sV_t1 is of size (1,1,hidden)
            sV_l1a = torch.cat(hidden_array_1a, dim=1)
            sV_l2a = torch.cat(hidden_array_2a, dim=1)
            sV_l1m = torch.cat(hidden_array_1m, dim=1)
            sV_l2m = torch.cat(hidden_array_2m, dim=1)
            
            sV_lla = torch.cat((sV_l1a,sV_l2a), dim=2)
            sV_llm = torch.cat((sV_l1m,sV_l2m), dim=2)



            
            #############################             
            # run text encoder second time
            #############################
            sT_t1,sT_t2,cT_t1,cT_t2 = self.init_hiddens()
            sT_t1,sT_t2 = sV_t1a+sV_t1m, sV_t2a+sV_t2m
            
            hidden_array_3 = []

            # here sT_t1, sT_t2 are the last hiddens from video, input to text encoder again
            for i in range(nQuestionWords):
                input_question = features_questions[j,i:i+1]
                sT_t1, cT_t1 = self.lstm_text_1(input_question, (sT_t1, cT_t1))
                sT_t2, cT_t2 = self.lstm_text_2(sT_t1, (sT_t2, cT_t2))
                hidden_array_3.append(sT_t2)
            #print('Text encoding One size', sT_t1.size(), sT_t2.size())

            # here sT_t1, sT_t2 is the last hidden
            sT_t = torch.cat( (sT_t1,sT_t2), dim=1)  # should be of size (1,1024)
            
            #####################
            # temporal attention
            #####################
            vid_att_a = self.TpAtt_a(sV_lla, sT_t)
            vid_att_m = self.TpAtt_m(sV_llm, sT_t)
            
            
            ################
            # ram memory
            ################
            sT_rl = torch.cat(hidden_array_3, dim=0)

            memory_ram_vid = self.mrm_vid(sV_l2a[0,:,:], sV_l2m[0,:,:], nImg)
            memory_ram_txt = self.mrm_txt(sT_rl, nQuestionWords)
                
            if self.mm_version==1:
                svt_tmp = torch.cat((sV_t2a,sV_t2m),dim=1)
                smq = self.mm_module_v1(svt_tmp,memory_ram_vid,memory_ram_txt,self.iter_num)
            elif self.mm_version==2:
                smq = self.mm_module_v2(memory_ram_vid,memory_ram_txt)
                
                
            ######################### 
            # decode the final output
            #########################
            final_embed_a = torch.tanh( self.linear_decoder_att_a(vid_att_a) )
            final_embed_m = torch.tanh( self.linear_decoder_att_m(vid_att_m) )
            final_embed_2 = torch.tanh( self.linear_decoder_mem(smq) )
            final_embed = torch.cat([final_embed_a,final_embed_m,final_embed_2],dim=1)
            
            output = self.linear_decoder_count_2(final_embed)
            prediction = torch.clamp(torch.round(output.detach()), min=1, max=10).int()
            
            outputs.append(output)
            predictions.append(prediction)

    
        outputs = torch.cat(outputs, 0)
        #targets = torch.cat(targets, 0)
        predictions = torch.cat(predictions, 0)
        #print(predictions.size())
        return outputs, answers, predictions
    
    def forward_action(self, data_dict):

        
        video_features, numImgs = data_dict['video_features'], data_dict['video_lengths']
        questions, question_lengths = data_dict['candidates'], data_dict['candidate_lengths']
        answers, num_mult_choices = data_dict['answers'], data_dict['num_mult_choices']
                
        outputs = []
        predictions = []
        #print(questions.size())   # (N, 5, 35), 5 multiple choices, each choice is a question of 35 max lengths
        batch_size = len(questions)  

        cnt = 0
        features_questions = self.embed(questions)

        
        for j in range(batch_size):
        
            nImg = numImgs[j]
            outputs_j = []
            predictions_j = []
            
            for n_cand in range(num_mult_choices):
                
                #print ' ', n_cand,
                nQuestionWords = question_lengths[j][n_cand]
                #nAnswerWords = candidate_lengths[j][n_cand]
                

                ################################
                # slice the input image features
                ################################
                feature = video_features[j,video_features.size(1)-nImg:]
                #print('current video feature size', feature.size())            

            
                #############################             
                # run text encoder first time
                #############################
                s1_t1,s1_t2,c1_t1,c1_t2 = self.init_hiddens()
            
                for i in range(nQuestionWords):
                    input_question = features_questions[j,n_cand,i:i+1]
                    s1_t1, c1_t1 = self.lstm_text_1(input_question, (s1_t1, c1_t1))
                    s1_t2, c1_t2 = self.lstm_text_2(s1_t1, (s1_t2, c1_t2))
            
                # here s1_t1, s1_t2 is the last hidden
                s1_t = torch.cat( (s1_t1,s1_t2), dim=1)  # should be of size (1,1024)
            
                
            
                ###########################################             
                # run video encoder with spatial attention
                ###########################################
                sV_t1a,sV_t2a,cV_t1a,cV_t2a = s1_t1,s1_t2,c1_t1,c1_t2
                sV_t1m,sV_t2m,cV_t1m,cV_t2m = s1_t1,s1_t2,c1_t1,c1_t2
            

                # record each time t, hidden states, for later temporal attention after text encoding
                hidden_array_1a = []
                hidden_array_2a = []
                hidden_array_1m = []
                hidden_array_2m = []
            
            
                for i in range(nImg):
                    if self.useSpatial:
                        input_frame = feature[i:i+1]
                        feat_att = self.SpAtt(input_frame, s1_t)
                    else:
                        feat_att_m = feature[i:i+1,0,0,:4096]
                        feat_att_a = feature[i:i+1,0,0,4096:]

                    sV_t1m, cV_t1m = self.lstm_video_1m(feat_att_m, (sV_t1m, cV_t1m))
                    sV_t2m, cV_t2m = self.lstm_video_2m(sV_t1m, (sV_t2m, cV_t2m))

                    sV_t1a, cV_t1a = self.lstm_video_1a(feat_att_a, (sV_t1a, cV_t1a))
                    sV_t2a, cV_t2a = self.lstm_video_2a(sV_t1a, (sV_t2a, cV_t2a))
                
                
                    sV_t1a_vec = sV_t1a.view(sV_t1a.size(0),1,sV_t1a.size(1))
                    sV_t2a_vec = sV_t2a.view(sV_t2a.size(0),1,sV_t2a.size(1))
                
                    hidden_array_1a.append(sV_t1a_vec)
                    hidden_array_2a.append(sV_t2a_vec)
                
                    sV_t1m_vec = sV_t1m.view(sV_t1m.size(0),1,sV_t1m.size(1))
                    sV_t2m_vec = sV_t2m.view(sV_t2m.size(0),1,sV_t2m.size(1))
                
                    hidden_array_1m.append(sV_t1m_vec)
                    hidden_array_2m.append(sV_t2m_vec)
                
                

                sV_l1a = torch.cat(hidden_array_1a, dim=1)
                sV_l2a = torch.cat(hidden_array_2a, dim=1)
                sV_l1m = torch.cat(hidden_array_1m, dim=1)
                sV_l2m = torch.cat(hidden_array_2m, dim=1)
            
                sV_lla = torch.cat((sV_l1a,sV_l2a), dim=2)
                sV_llm = torch.cat((sV_l1m,sV_l2m), dim=2)
                

            
                #############################             
                # run text encoder second time
                #############################
                sT_t1,sT_t2,cT_t1,cT_t2 = self.init_hiddens()
                sT_t1,sT_t2 = sV_t1a+sV_t1m, sV_t2a+sV_t2m
            
                hidden_array_3 = []
                
                for i in range(nQuestionWords):
                    input_question = features_questions[j,n_cand,i:i+1]
                    sT_t1, cT_t1 = self.lstm_text_1(input_question, (sT_t1, cT_t1))
                    sT_t2, cT_t2 = self.lstm_text_2(sT_t1, (sT_t2, cT_t2))
                    hidden_array_3.append(sT_t2)
                    
                #print('Text encoding One size', sT_t1.size(), sT_t2.size())

                # here sT_t1, sT_t2 is the last hidden
                sT_t = torch.cat( (sT_t1,sT_t2), dim=1)  # should be of size (1,1024)
                
                #####################
                # temporal attention
                #####################
                vid_att_a = self.TpAtt_a(sV_lla, sT_t)
                vid_att_m = self.TpAtt_m(sV_llm, sT_t)
            

                ################
                # ram memory
                ################
                sT_rl = torch.cat(hidden_array_3, dim=0)

                memory_ram_vid = self.mrm_vid(sV_l2a[0,:,:], sV_l2m[0,:,:], nImg)
                memory_ram_txt = self.mrm_txt(sT_rl, nQuestionWords)

                if self.mm_version==1:
                    svt_tmp = torch.cat((sV_t2a,sV_t2m),dim=1)
                    smq = self.mm_module_v1(svt_tmp,memory_ram_vid,memory_ram_txt,self.iter_num)
                elif self.mm_version==2:
                    smq = self.mm_module_v2(memory_ram_vid,memory_ram_txt)
                
                
                ######################### 
                # decode the final output
                ######################### 
                
                final_embed_a = torch.tanh( self.linear_decoder_att_a(vid_att_a) )
                final_embed_m = torch.tanh( self.linear_decoder_att_m(vid_att_m) )
                final_embed_2 = torch.tanh( self.linear_decoder_mem(smq) )
                final_embed = torch.cat([final_embed_a,final_embed_m,final_embed_2],dim=1)
            
                
                output = self.linear_decoder_count_2(final_embed)                
                outputs_j.append(output)
        
            
            # output is the score of each multiple choice
            outputs_j = torch.cat(outputs_j,1)
            #print('Output_j size', outputs_j.size()) # (1,5)
            outputs.append(outputs_j)
            
            # to evaluate accuracy, take the highest-scoring choice
            _,mx_idx = torch.max(outputs_j,1)
            predictions.append(mx_idx)
            #print(outputs_j,mx_idx)
            
        outputs = torch.cat(outputs, 0)
        predictions = torch.cat(predictions, 0)
        return outputs, answers, predictions
           
    
    def forward_trans(self, data_dict):

        
        video_features, numImgs = data_dict['video_features'], data_dict['video_lengths']
        questions, question_lengths = data_dict['candidates'], data_dict['candidate_lengths']
        answers, num_mult_choices = data_dict['answers'], data_dict['num_mult_choices']
        
                    
        
        outputs = []
        predictions = []
        #print(questions.size())   # (N, 5, 35), 5 multiple choices, each choice is a question of 35 max lengths
        batch_size = len(questions)  

        cnt = 0        
        features_questions = self.embed(questions)

        
        
        for j in range(batch_size):
        
            nImg = numImgs[j]
            outputs_j = []
            predictions_j = []
            
            for n_cand in range(num_mult_choices):
                
                nQuestionWords = question_lengths[j][n_cand]
                #nAnswerWords = answer_lengths[j]

                ################################
                # slice the input image features
                ################################
                feature = video_features[j,video_features.size(1)-nImg:]
                #print('current video feature size', feature.size())            

            
                #############################             
                # run text encoder first time
                #############################
                s1_t1,s1_t2,c1_t1,c1_t2 = self.init_hiddens()
            
                for i in range(nQuestionWords):
                    input_question = features_questions[j,n_cand,i:i+1]
                    s1_t1, c1_t1 = self.lstm_text_1(input_question, (s1_t1, c1_t1))
                    s1_t2, c1_t2 = self.lstm_text_2(s1_t1, (s1_t2, c1_t2))
            
                # here s1_t1, s1_t2 is the last hidden
                s1_t = torch.cat( (s1_t1,s1_t2), dim=1)  # should be of size (1,1024)
            
                
            
                ###########################################             
                # run video encoder with spatial attention
                ###########################################
                sV_t1a,sV_t2a,cV_t1a,cV_t2a = s1_t1,s1_t2,c1_t1,c1_t2
                sV_t1m,sV_t2m,cV_t1m,cV_t2m = s1_t1,s1_t2,c1_t1,c1_t2
            
                # record each time t, hidden states, for later temporal attention after text encoding
                hidden_array_1a = []
                hidden_array_2a = []
                hidden_array_1m = []
                hidden_array_2m = []
            
                for i in range(nImg):

                    if self.useSpatial:
                        input_frame = feature[i:i+1]
                        feat_att = self.SpAtt(input_frame, s1_t)
                    else:
                        feat_att_m = feature[i:i+1,0,0,:4096]
                        feat_att_a = feature[i:i+1,0,0,4096:]
                        
                                                
                    # lstm
                    sV_t1m, cV_t1m = self.lstm_video_1m(feat_att_m, (sV_t1m, cV_t1m))
                    sV_t2m, cV_t2m = self.lstm_video_2m(sV_t1m, (sV_t2m, cV_t2m))

                    sV_t1a, cV_t1a = self.lstm_video_1a(feat_att_a, (sV_t1a, cV_t1a))
                    sV_t2a, cV_t2a = self.lstm_video_2a(sV_t1a, (sV_t2a, cV_t2a))
                
                    sV_t1a_vec = sV_t1a.view(sV_t1a.size(0),1,sV_t1a.size(1))
                    sV_t2a_vec = sV_t2a.view(sV_t2a.size(0),1,sV_t2a.size(1))
                
                    hidden_array_1a.append(sV_t1a_vec)
                    hidden_array_2a.append(sV_t2a_vec)
                
                    sV_t1m_vec = sV_t1m.view(sV_t1m.size(0),1,sV_t1m.size(1))
                    sV_t2m_vec = sV_t2m.view(sV_t2m.size(0),1,sV_t2m.size(1))
                
                    hidden_array_1m.append(sV_t1m_vec)
                    hidden_array_2m.append(sV_t2m_vec)
                    

                sV_l1a = torch.cat(hidden_array_1a, dim=1)
                sV_l2a = torch.cat(hidden_array_2a, dim=1)
                sV_l1m = torch.cat(hidden_array_1m, dim=1)
                sV_l2m = torch.cat(hidden_array_2m, dim=1)
            
                sV_lla = torch.cat((sV_l1a,sV_l2a), dim=2)
                sV_llm = torch.cat((sV_l1m,sV_l2m), dim=2)

            
                #############################             
                # run text encoder second time
                #############################
                sT_t1,sT_t2,cT_t1,cT_t2 = self.init_hiddens()
                sT_t1,sT_t2 = sV_t1a+sV_t1m, sV_t2a+sV_t2m
                
                hidden_array_3 = []
                
                for i in range(nQuestionWords):
                    input_question = features_questions[j,n_cand,i:i+1]
                    sT_t1, cT_t1 = self.lstm_text_1(input_question, (sT_t1, cT_t1))
                    sT_t2, cT_t2 = self.lstm_text_2(sT_t1, (sT_t2, cT_t2))
                    hidden_array_3.append(sT_t2)
                    
                #print('Text encoding One size', sT_t1.size(), sT_t2.size())

                # here sT_t1, sT_t2 is the last hidden
                sT_t = torch.cat( (sT_t1,sT_t2), dim=1)  # should be of size (1,1024)
                
                #####################
                # temporal attention
                #####################
                vid_att_a = self.TpAtt_a(sV_lla, sT_t)
                vid_att_m = self.TpAtt_m(sV_llm, sT_t)


                ################
                # stack memory
                ################
                sT_rl = torch.cat(hidden_array_3, dim=0)

                memory_ram_vid = self.mrm_vid(sV_l2a[0,:,:], sV_l2m[0,:,:], nImg)
                memory_ram_txt = self.mrm_txt(sT_rl, nQuestionWords)

                if self.mm_version==1:
                    svt_tmp = torch.cat((sV_t2a,sV_t2m),dim=1)
                    smq = self.mm_module_v1(svt_tmp,memory_ram_vid,memory_ram_txt,self.iter_num)
                elif self.mm_version==2:
                    smq = self.mm_module_v2(memory_ram_vid,memory_ram_txt)

                                
                ######################### 
                # decode the final output
                ######################### 

                final_embed_a = torch.tanh( self.linear_decoder_att_a(vid_att_a) )
                final_embed_m = torch.tanh( self.linear_decoder_att_m(vid_att_m) )
                final_embed_2 = torch.tanh( self.linear_decoder_mem(smq) )
                final_embed = torch.cat([final_embed_a,final_embed_m,final_embed_2],dim=1)
            

                output = self.linear_decoder_count_2(final_embed)                
                outputs_j.append(output)
        
            
            # output is the score of each multiple choice
            outputs_j = torch.cat(outputs_j,1)
            #print('Output_j size', outputs_j.size()) # (1,5)
            outputs.append(outputs_j)
            
            # to evaluate accuracy, take the highest-scoring choice
            _,mx_idx = torch.max(outputs_j,1)
            predictions.append(mx_idx)
            #print(outputs_j,mx_idx)
            
        outputs = torch.cat(outputs, 0)
        predictions = torch.cat(predictions, 0)
        return outputs, answers, predictions
    
    
    def forward_frameqa(self, data_dict):
        
        video_features, numImgs = data_dict['video_features'], data_dict['video_lengths']
        questions, question_lengths = data_dict['question_words'], data_dict['question_lengths']
        answers = data_dict['answers']           
                                        
        
        outputs = []
        predictions = []
        bsize = len(questions)
        batch_size = len(questions)  # batch size has to be 1


        cnt = 0
        
        features_questions = self.embed(questions)
        # (64, 35, 300)
        #print('text feature size', features_questions.size())
        
        
        for j in range(batch_size):
        
            nImg = numImgs[j]
            nQuestionWords = question_lengths[j]

            ################################
            # slice the input image features
            ################################
            feature = video_features[j,video_features.size(1)-nImg:]
            #print('current video feature size', feature.size())            

            
            #############################             
            # run text encoder first time
            #############################
            s1_t1,s1_t2,c1_t1,c1_t2 = self.init_hiddens()
            
            for i in range(nQuestionWords):
                input_question = features_questions[j,i:i+1]
                s1_t1, c1_t1 = self.lstm_text_1(input_question, (s1_t1, c1_t1))
                s1_t2, c1_t2 = self.lstm_text_2(s1_t1, (s1_t2, c1_t2))
            
            # here s1_t1, s1_t2 is the last hidden
            s1_t = torch.cat( (s1_t1,s1_t2), dim=1)  # should be of size (1,1024)
            
            
            
            ###########################################             
            # run video encoder with spatial attention
            ###########################################
            sV_t1a,sV_t2a,cV_t1a,cV_t2a = s1_t1,s1_t2,c1_t1,c1_t2
            sV_t1m,sV_t2m,cV_t1m,cV_t2m = s1_t1,s1_t2,c1_t1,c1_t2

            # record each time t, hidden states, for later temporal attention after text encoding
            hidden_array_1a = []
            hidden_array_2a = []
            hidden_array_1m = []
            hidden_array_2m = []
            
            for i in range(nImg):

                if self.useSpatial:
                    input_frame = feature[i:i+1]
                    feat_att = self.SpAtt(input_frame, s1_t)
                else:
                    feat_att_m = feature[i:i+1,0,0,:4096]
                    feat_att_a = feature[i:i+1,0,0,4096:]
                
                sV_t1m, cV_t1m = self.lstm_video_1m(feat_att_m, (sV_t1m, cV_t1m))
                sV_t2m, cV_t2m = self.lstm_video_2m(sV_t1m, (sV_t2m, cV_t2m))

                sV_t1a, cV_t1a = self.lstm_video_1a(feat_att_a, (sV_t1a, cV_t1a))
                sV_t2a, cV_t2a = self.lstm_video_2a(sV_t1a, (sV_t2a, cV_t2a))
                
                sV_t1a_vec = sV_t1a.view(sV_t1a.size(0),1,sV_t1a.size(1))
                sV_t2a_vec = sV_t2a.view(sV_t2a.size(0),1,sV_t2a.size(1))
            
                hidden_array_1a.append(sV_t1a_vec)
                hidden_array_2a.append(sV_t2a_vec)
            
                sV_t1m_vec = sV_t1m.view(sV_t1m.size(0),1,sV_t1m.size(1))
                sV_t2m_vec = sV_t2m.view(sV_t2m.size(0),1,sV_t2m.size(1))
            
                hidden_array_1m.append(sV_t1m_vec)
                hidden_array_2m.append(sV_t2m_vec)
                

            sV_l1a = torch.cat(hidden_array_1a, dim=1)
            sV_l2a = torch.cat(hidden_array_2a, dim=1)
            sV_l1m = torch.cat(hidden_array_1m, dim=1)
            sV_l2m = torch.cat(hidden_array_2m, dim=1)
        
            sV_lla = torch.cat((sV_l1a,sV_l2a), dim=2)
            sV_llm = torch.cat((sV_l1m,sV_l2m), dim=2)

               
            #############################             
            # run text encoder second time
            #############################
            sT_t1,sT_t2,cT_t1,cT_t2 = self.init_hiddens()
            sT_t1,sT_t2 = sV_t1a+sV_t1m, sV_t2a+sV_t2m
            
            hidden_array_3 = []
            
                
            for i in range(nQuestionWords):
                input_question = features_questions[j,i:i+1]
                sT_t1, cT_t1 = self.lstm_text_1(input_question, (sT_t1, cT_t1))
                sT_t2, cT_t2 = self.lstm_text_2(sT_t1, (sT_t2, cT_t2))
                hidden_array_3.append(sT_t2)
            #print('Text encoding One size', sT_t1.size(), sT_t2.size())

            # here sT_t1, sT_t2 is the last hidden
            sT_t = torch.cat( (sT_t1,sT_t2), dim=1)  # should be of size (1,1024)
            
            
            #####################
            # temporal attention
            #####################
            vid_att_a = self.TpAtt_a(sV_lla, sT_t)
            vid_att_m = self.TpAtt_m(sV_llm, sT_t)


            ################
            # ram memory
            ################
            sT_rl = torch.cat(hidden_array_3, dim=0)

            memory_ram_vid = self.mrm_vid(sV_l2a[0,:,:], sV_l2m[0,:,:], nImg)
            memory_ram_txt = self.mrm_txt(sT_rl, nQuestionWords)
                
            if self.mm_version==1:
                svt_tmp = torch.cat((sV_t2a,sV_t2m),dim=1)
                smq = self.mm_module_v1(svt_tmp,memory_ram_vid,memory_ram_txt,self.iter_num)
            elif self.mm_version==2:
                smq = self.mm_module_v2(memory_ram_vid,memory_ram_txt)
                
            ######################### 
            # decode the final output
            ######################### 
            
            final_embed_a = torch.tanh( self.linear_decoder_att_a(vid_att_a) )
            final_embed_m = torch.tanh( self.linear_decoder_att_m(vid_att_m) )
            final_embed_2 = torch.tanh( self.linear_decoder_mem(smq) )
            final_embed = torch.cat([final_embed_a,final_embed_m,final_embed_2],dim=1)
        
                
            output = self.linear_decoder_count_2(final_embed)
            #print('Output size', output.size()) # (1,5)
            outputs.append(output)
            
            _,mx_idx = torch.max(output,1)
            predictions.append(mx_idx)
            #print(output,mx_idx)

    
        outputs = torch.cat(outputs, 0)
        #targets = torch.cat(targets, 0)
        predictions = torch.cat(predictions, 0)
        #print(predictions.size())
        return outputs, answers[:,0], predictions
    

    def accuracy(self, logits, targets):
        correct = torch.sum(logits.eq(targets)).float()
        return correct * 100.0 / targets.size(0)
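
For orientation, a rough driver for the Count task is sketched below; it assumes a CUDA device plus the repository's attention and memory modules, and the per-frame feature layout (trailing dimension of 4096 motion + 2048 appearance features) is inferred from the slicing and LSTM input sizes above rather than documented:

import torch

# Hypothetical, illustrative hyper-parameters.
model = AttentionTwoStream(task='Count', feat_channel=6144, feat_dim=7,
                           text_embed_size=300, hidden_size=512, vocab_size=8000,
                           num_layers=2, word_matrix=None).cuda()

data_dict = {
    'video_features': torch.randn(1, 20, 1, 1, 6144).cuda(),   # (batch, frames, 1, 1, feat)
    'video_lengths': [20],
    'question_words': torch.randint(0, 8000, (1, 15)).cuda(),  # token indices
    'question_lengths': [15],
    'answers': torch.tensor([[3.0]]).cuda(),
}
outputs, answers, predictions = model(data_dict, question_type='Count')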
Example #4
    def __init__(self,
                 opt,
                 feat_channel,
                 feat_dim,
                 text_embed_size,
                 hidden_size,
                 vocab_size,
                 num_layers,
                 word_matrix,
                 answer_vocab_size=None,
                 max_len=20,
                 dropout=0.2,
                 mm_version=1,
                 useSpatial=False,
                 useNaive=False,
                 mrmUseOriginFeat=False,
                 iter_num=3):
        """Set the hyper-parameters and build the layers."""
        super(AttentionTwoStream, self).__init__()

        # text input size
        self.text_embed_size = text_embed_size  # should be 300

        # video input size
        self.feat_channel = feat_channel
        self.feat_dim = feat_dim  # should be 7

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.useNaive = useNaive
        self.mrmUseOriginFeat = mrmUseOriginFeat
        self.useSpatial = useSpatial
        self.mm_version = mm_version
        self.iter_num = iter_num

        self.TpAtt_a = TemporalAttentionModule(hidden_size * 2, hidden_size)
        self.TpAtt_m = TemporalAttentionModule(hidden_size * 2, hidden_size)

        if useSpatial:
            self.SpAtt = SpatialAttentionModule(feat_channel, feat_dim,
                                                hidden_size)
        else:
            self.video_encoder = nn.Linear(feat_channel, hidden_size)

        self.drop_keep_prob_final_att_vec = nn.Dropout(dropout)
        self.embed = nn.Embedding(vocab_size, text_embed_size)
        ####
        self.opt = opt
        # self.bert = BertModel.from_pretrained('bert-base-uncased')
        # for param in self.bert.parameters():
        #     param.requires_grad = False
        # self.bert_fc = nn.Sequential(
        #             nn.Dropout(0.5),
        #             nn.Linear(768, 300),
        #             nn.Tanh(),
        #         )
        ####

        self.lstm_text_1 = nn.LSTMCell(text_embed_size, hidden_size)
        self.lstm_text_2 = nn.LSTMCell(hidden_size, hidden_size)

        self.lstm_video_1a = nn.LSTMCell(4096, hidden_size)
        self.lstm_video_2a = nn.LSTMCell(hidden_size, hidden_size)

        self.lstm_video_1m = nn.LSTMCell(4096, hidden_size)
        self.lstm_video_2m = nn.LSTMCell(hidden_size, hidden_size)

        if mm_version == 1:
            self.lstm_mm_1 = nn.LSTMCell(hidden_size, hidden_size)
            self.lstm_mm_2 = nn.LSTMCell(hidden_size, hidden_size)
            self.linear_decoder_mem = nn.Linear(hidden_size * 2, hidden_size)
            self.hidden_encoder_1 = nn.Linear(hidden_size * 2, hidden_size)
            self.hidden_encoder_2 = nn.Linear(hidden_size * 2, hidden_size)

        else:
            self.gru_mm = nn.GRUCell(hidden_size, hidden_size)
            self.linear_decoder_mem = nn.Linear(hidden_size, hidden_size)

        self.mm_att = MultiModalAttentionModule(opt, hidden_size=hidden_size)
        self.linear_decoder_att_a = nn.Linear(hidden_size * 2, hidden_size)
        self.linear_decoder_att_m = nn.Linear(hidden_size * 2, hidden_size)

        if answer_vocab_size is not None:
            self.linear_decoder_count_2 = nn.Linear(
                hidden_size * 2 + hidden_size, answer_vocab_size)
        else:
            self.linear_decoder_count_2 = nn.Linear(
                hidden_size * 2 + hidden_size,
                1)  # Count is regression problem

        self.max_len = max_len

        self.mrm_vid = MemoryRamTwoStreamModule(hidden_size, hidden_size,
                                                max_len)
        self.mrm_txt = MemoryRamModule(hidden_size, hidden_size, max_len)

        self.init_weights(word_matrix)
Example #5
import torch
import torch.nn as nn

# TemporalAttentionModule, SpatialAttentionModule, MultiModalAttentionModule,
# MemoryRamModule and MemoryRamTwoStreamModule are defined elsewhere in the repository.


class AttentionTwoStream(nn.Module):
    # Args renamed to opt for convenience
    def __init__(self,
                 opt,
                 feat_channel,
                 feat_dim,
                 text_embed_size,
                 hidden_size,
                 vocab_size,
                 num_layers,
                 word_matrix,
                 answer_vocab_size=None,
                 max_len=20,
                 dropout=0.2,
                 mm_version=1,
                 useSpatial=False,
                 useNaive=False,
                 mrmUseOriginFeat=False,
                 iter_num=3):
        """Set the hyper-parameters and build the layers."""
        super(AttentionTwoStream, self).__init__()

        # text input size
        self.text_embed_size = text_embed_size  # should be 300

        # video input size
        self.feat_channel = feat_channel
        self.feat_dim = feat_dim  # should be 7

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.useNaive = useNaive
        self.mrmUseOriginFeat = mrmUseOriginFeat
        self.useSpatial = useSpatial
        self.mm_version = mm_version
        self.iter_num = iter_num

        self.TpAtt_a = TemporalAttentionModule(hidden_size * 2, hidden_size)
        self.TpAtt_m = TemporalAttentionModule(hidden_size * 2, hidden_size)

        if useSpatial:
            self.SpAtt = SpatialAttentionModule(feat_channel, feat_dim,
                                                hidden_size)
        else:
            self.video_encoder = nn.Linear(feat_channel, hidden_size)

        self.drop_keep_prob_final_att_vec = nn.Dropout(dropout)
        self.embed = nn.Embedding(vocab_size, text_embed_size)
        ####
        self.opt = opt
        # self.bert = BertModel.from_pretrained('bert-base-uncased')
        # for param in self.bert.parameters():
        #     param.requires_grad = False
        # self.bert_fc = nn.Sequential(
        #             nn.Dropout(0.5),
        #             nn.Linear(768, 300),
        #             nn.Tanh(),
        #         )
        ####

        self.lstm_text_1 = nn.LSTMCell(text_embed_size, hidden_size)
        self.lstm_text_2 = nn.LSTMCell(hidden_size, hidden_size)

        self.lstm_video_1a = nn.LSTMCell(4096, hidden_size)
        self.lstm_video_2a = nn.LSTMCell(hidden_size, hidden_size)

        self.lstm_video_1m = nn.LSTMCell(4096, hidden_size)
        self.lstm_video_2m = nn.LSTMCell(hidden_size, hidden_size)

        if mm_version == 1:
            self.lstm_mm_1 = nn.LSTMCell(hidden_size, hidden_size)
            self.lstm_mm_2 = nn.LSTMCell(hidden_size, hidden_size)
            self.linear_decoder_mem = nn.Linear(hidden_size * 2, hidden_size)
            self.hidden_encoder_1 = nn.Linear(hidden_size * 2, hidden_size)
            self.hidden_encoder_2 = nn.Linear(hidden_size * 2, hidden_size)

        else:
            self.gru_mm = nn.GRUCell(hidden_size, hidden_size)
            self.linear_decoder_mem = nn.Linear(hidden_size, hidden_size)

        self.mm_att = MultiModalAttentionModule(opt, hidden_size=hidden_size)
        self.linear_decoder_att_a = nn.Linear(hidden_size * 2, hidden_size)
        self.linear_decoder_att_m = nn.Linear(hidden_size * 2, hidden_size)

        if answer_vocab_size is not None:
            self.linear_decoder_count_2 = nn.Linear(
                hidden_size * 2 + hidden_size, answer_vocab_size)
        else:
            self.linear_decoder_count_2 = nn.Linear(
                hidden_size * 2 + hidden_size,
                1)  # Count is regression problem

        self.max_len = max_len

        self.mrm_vid = MemoryRamTwoStreamModule(hidden_size, hidden_size,
                                                max_len)
        self.mrm_txt = MemoryRamModule(hidden_size, hidden_size, max_len)

        self.init_weights(word_matrix)

    def init_weights(self, word_matrix):
        """Initialize weights."""

        if word_matrix is None:
            self.embed.weight.data.uniform_(-0.1, 0.1)
        else:
            # init embed from glove
            self.embed.weight.data.copy_(torch.from_numpy(word_matrix))

        self.mrm_vid.init_weights()
        self.mrm_txt.init_weights()

    def init_hiddens(self):
        s_t = torch.zeros(1, self.hidden_size).cuda()
        s_t2 = torch.zeros(1, self.hidden_size).cuda()
        c_t = torch.zeros(1, self.hidden_size).cuda()
        c_t2 = torch.zeros(1, self.hidden_size).cuda()
        return s_t, s_t2, c_t, c_t2

    def mm_module_v1(self, svt_tmp, memory_ram_vid, memory_ram_txt, loop=3):

        sm_q1, sm_q2, cm_q1, cm_q2 = self.init_hiddens()
        mm_oo = self.drop_keep_prob_final_att_vec(
            torch.tanh(self.hidden_encoder_1(svt_tmp)))

        for _ in range(loop):

            sm_q1, cm_q1 = self.lstm_mm_1(mm_oo, (sm_q1, cm_q1))
            sm_q2, cm_q2 = self.lstm_mm_2(sm_q1, (sm_q2, cm_q2))

            mm_o1 = self.mm_att(sm_q2, memory_ram_vid, memory_ram_txt)
            mm_o2 = torch.cat((sm_q2, mm_o1), dim=1)
            mm_oo = self.drop_keep_prob_final_att_vec(
                torch.tanh(self.hidden_encoder_2(mm_o2)))

        smq = torch.cat((sm_q1, sm_q2), dim=1)

        return smq

    def mm_module_v2(self, memory_ram_vid, memory_ram_txt, loop=5):

        h_t = torch.zeros(1, self.hidden_size).cuda()

        for _ in range(loop):
            mm_o = self.mm_att(h_t, memory_ram_vid, memory_ram_txt)
            h_t = self.gru_mm(mm_o, h_t)

        return h_t

    def forward(self, data_dict):
        ret = self.forward_frameqa(data_dict)
        return ret

    def forward_frameqa(self, data_dict):
        video_features, numImgs = data_dict['video_features'], data_dict['video_lengths']
        questions, question_lengths = data_dict['question_words'], data_dict[
            'question_lengths']

        outputs = []
        predictions = []
        bsize = len(questions)
        batch_size = len(questions)  # batch size has to be 1
        ########
        #Flatten the batch and 5 candidates, process them with bert and FC layer and then recover them
        features_questions = self.embed(questions)
        # import ipdb; ipdb.set_trace()
        # questions = questions.view(-1, questions.shape[-1])
        # features_questions = self.bert(questions)[0][11]    # Final bert layer
        # features_questions = self.bert_fc(features_questions)
        # features_questions = features_questions.view(batch_size, 5, questions.shape[-1], 300)
        ########

        for j in range(batch_size):
            nImg = numImgs[j]
            nQuestionWords = question_lengths[j]

            ################################
            # slice the input image features
            ################################
            feature = video_features[j, video_features.size(1) - nImg:]
            #print('current video feature size', feature.size())

            #############################
            # run text encoder first time
            #############################
            s1_t1, s1_t2, c1_t1, c1_t2 = self.init_hiddens()

            for i in range(nQuestionWords):
                input_question = features_questions[j, i:i + 1]
                s1_t1, c1_t1 = self.lstm_text_1(input_question, (s1_t1, c1_t1))
                s1_t2, c1_t2 = self.lstm_text_2(s1_t1, (s1_t2, c1_t2))

            # here s1_t1, s1_t2 is the last hidden
            s1_t = torch.cat((s1_t1, s1_t2),
                             dim=1)  # should be of size (1,1024)

            ###########################################
            # run video encoder with spatial attention
            ###########################################
            sV_t1a, sV_t2a, cV_t1a, cV_t2a = s1_t1, s1_t2, c1_t1, c1_t2
            sV_t1m, sV_t2m, cV_t1m, cV_t2m = s1_t1, s1_t2, c1_t1, c1_t2

            # record each time t, hidden states, for later temporal attention after text encoding
            hidden_array_1a = []
            hidden_array_2a = []
            hidden_array_1m = []
            hidden_array_2m = []

            for i in range(nImg):

                if self.useSpatial:
                    input_frame = feature[i:i + 1]
                    feat_att = self.SpAtt(input_frame, s1_t)
                else:
                    feat_att_m = feature[i:i + 1, 0, 0, :4096]
                    feat_att_a = feature[i:i + 1, 0, 0, 4096:]

                sV_t1m, cV_t1m = self.lstm_video_1m(feat_att_m,
                                                    (sV_t1m, cV_t1m))
                sV_t2m, cV_t2m = self.lstm_video_2m(sV_t1m, (sV_t2m, cV_t2m))

                sV_t1a, cV_t1a = self.lstm_video_1a(feat_att_a,
                                                    (sV_t1a, cV_t1a))
                sV_t2a, cV_t2a = self.lstm_video_2a(sV_t1a, (sV_t2a, cV_t2a))

                sV_t1a_vec = sV_t1a.view(sV_t1a.size(0), 1, sV_t1a.size(1))
                sV_t2a_vec = sV_t2a.view(sV_t2a.size(0), 1, sV_t2a.size(1))

                hidden_array_1a.append(sV_t1a_vec)
                hidden_array_2a.append(sV_t2a_vec)

                sV_t1m_vec = sV_t1m.view(sV_t1m.size(0), 1, sV_t1m.size(1))
                sV_t2m_vec = sV_t2m.view(sV_t2m.size(0), 1, sV_t2m.size(1))

                hidden_array_1m.append(sV_t1m_vec)
                hidden_array_2m.append(sV_t2m_vec)

            sV_l1a = torch.cat(hidden_array_1a, dim=1)
            sV_l2a = torch.cat(hidden_array_2a, dim=1)
            sV_l1m = torch.cat(hidden_array_1m, dim=1)
            sV_l2m = torch.cat(hidden_array_2m, dim=1)

            sV_lla = torch.cat((sV_l1a, sV_l2a), dim=2)
            sV_llm = torch.cat((sV_l1m, sV_l2m), dim=2)

            #############################
            # run text encoder second time
            #############################
            sT_t1, sT_t2, cT_t1, cT_t2 = self.init_hiddens()
            sT_t1, sT_t2 = sV_t1a + sV_t1m, sV_t2a + sV_t2m

            hidden_array_3 = []

            for i in range(nQuestionWords):
                input_question = features_questions[j, i:i + 1]
                sT_t1, cT_t1 = self.lstm_text_1(input_question, (sT_t1, cT_t1))
                sT_t2, cT_t2 = self.lstm_text_2(sT_t1, (sT_t2, cT_t2))
                hidden_array_3.append(sT_t2)

            # here sT_t1, sT_t2 is the last hidden
            sT_t = torch.cat((sT_t1, sT_t2),
                             dim=1)  # should be of size (1,1024)

            #####################
            # temporal attention
            #####################
            vid_att_a = self.TpAtt_a(sV_lla, sT_t)
            vid_att_m = self.TpAtt_m(sV_llm, sT_t)

            ################
            # ram memory
            ################
            sT_rl = torch.cat(hidden_array_3, dim=0)

            memory_ram_vid = self.mrm_vid(sV_l2a[0, :, :], sV_l2m[0, :, :],
                                          nImg)
            memory_ram_txt = self.mrm_txt(sT_rl, nQuestionWords)

            svt_tmp = torch.cat((sV_t2a, sV_t2m), dim=1)
            smq = self.mm_module_v1(svt_tmp, memory_ram_vid, memory_ram_txt,
                                    self.iter_num)

            #########################
            # decode the final output
            #########################

            final_embed_a = torch.tanh(self.linear_decoder_att_a(vid_att_a))
            final_embed_m = torch.tanh(self.linear_decoder_att_m(vid_att_m))
            final_embed_2 = torch.tanh(self.linear_decoder_mem(smq))
            final_embed = torch.cat(
                [final_embed_a, final_embed_m, final_embed_2], dim=1)

            output = self.linear_decoder_count_2(final_embed)
            outputs.append(output)

            _, mx_idx = torch.max(output, 1)
            predictions.append(mx_idx)

        outputs = torch.cat(outputs, 0)
        predictions = torch.cat(predictions, 0)
        return outputs, predictions

    def accuracy(self, logits, targets):
        correct = torch.sum(logits.eq(targets)).float()
        return correct * 100.0 / targets.size(0)
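
The accuracy helper compares its two arguments element-wise, so despite the parameter name logits it is really meant to receive predicted class indices alongside the targets; a toy check (model being any instantiated AttentionTwoStream):

preds = torch.tensor([2, 0, 1, 1])
targets = torch.tensor([2, 1, 1, 0])
print(model.accuracy(preds, targets))  # 2 of 4 match -> tensor(50.)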