def forward(self, isequence, ilenghts, queryvector=None):
    #isequence : batch_size*num_steps*input_emb
    #ilenghts : batch_size
    if isequence.dim() == 5:
        bs, nf, c, h, w = isequence.size()
        isequence = isequence.view(bs, nf, c, -1)
        isequence = isequence.sum(dim=-1)
        isequence = isequence / (h * w)

    isequence = self.dropout_layer(isequence)
    #isequence = functional.relu(self.linear(isequence))

    contextvector = None
    imask = Variable(utils.sequence_mask(ilenghts))

    if self.configuration == 'MP':
        if self.contextvector is not None:
            contextvector = self.contextvector
        else:
            #print("MP")
            contextvector = isequence.sum(dim=1)
            contextvector = contextvector.div(ilenghts.unsqueeze(1).float())
            #contextvector : batch_size*input_emb
            self.contextvector = contextvector

    if self.recurrent:
        if self.contextvector is not None and self.inputvector is not None:
            contextvector = self.contextvector
            isequence = self.inputvector
        else:
            #print("Recurrent")
            isequence, ihidden, _ = self.recurrent_layer(isequence, ilenghts)
            #isequence : batch_size*num_steps*hidden_dim
            #ihidden : batch_size*hidden_dim
            contextvector = ihidden
            self.contextvector = contextvector
            self.inputvector = isequence

    if self.attention:
        #print("Attention")
        vector_sequence_attention, sequence_attention_weights = self.attention_function(
            isequence, queryvector, imask, softmax=True)
        contextvector = vector_sequence_attention
        #contextvector : batch_size*input_emb/hidden_dim

    return contextvector
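# A minimal, self-contained sketch of the 'MP' (mean-pooling) path above, assuming that
# utils.sequence_mask(lengths) returns a batch_size*max_len mask with ones up to each
# sequence length. The helper names below (make_sequence_mask, masked_mean_pool) are
# illustrative, not part of the original codebase; the sketch zeroes padded steps
# explicitly, whereas the encoder above relies on masking applied upstream.
import torch

def make_sequence_mask(lengths, max_len=None):
    #lengths: batch_size LongTensor -> mask: batch_size*max_len (1.0 for valid steps)
    max_len = max_len or int(lengths.max())
    steps = torch.arange(max_len).unsqueeze(0)               #1*max_len
    return (steps < lengths.unsqueeze(1)).float()            #batch_size*max_len

def masked_mean_pool(sequence, lengths):
    #sequence: batch_size*num_steps*input_emb, lengths: batch_size
    mask = make_sequence_mask(lengths, sequence.size(1))     #batch_size*num_steps
    sequence = sequence * mask.unsqueeze(2)                  #zero out padded steps
    return sequence.sum(dim=1) / lengths.unsqueeze(1).float()  #batch_size*input_emb

if __name__ == '__main__':
    seq = torch.randn(2, 5, 4)
    print(masked_mean_pool(seq, torch.LongTensor([5, 3])).size())  #torch.Size([2, 4])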
def forward(self, videoframes, videoframes_lengths, inputwords, captionwords_lengths):
    #videoframes : batch_size*num_frames*3*224*224
    #videoframes_lengths : batch_size
    #inputwords : batch_size*num_words
    #outputwords : batch_size*num_words
    #captionwords_lengths : batch_size

    videoframes_mask = Variable(utils.sequence_mask(videoframes_lengths))
    #videoframes_mask: batch_size*num_frames

    #redundant because we are masking the loss
    #captionwords_mask = Variable(utils.sequence_mask(captionwords_lengths))
    #captionwords_mask: batch_size*num_words

    batch_size, num_frames, rgb, height, width = videoframes.size()
    videoframes = videoframes.view(-1, rgb, height, width)
    #videoframes : batch_size.num_frames*3*224*224

    videoframefeatures = self.pretrained_vision_layer(videoframes)
    videoframefeatures_fc = videoframefeatures[1]
    #videoframefeatures_fc : batch_size.num_frames*1000
    videoframefeatures_fc = self.vision_feature_dimred_layer(videoframefeatures_fc)
    #videoframefeatures_fc : batch_size.num_frames*rnn_hdim

    _, feature_dim = videoframefeatures_fc.size()
    videoframefeatures_fc = videoframefeatures_fc.view(batch_size, num_frames, feature_dim)
    #videoframefeatures_fc : batch_size*num_frames*rnn_hdim
    videoframefeatures_fc = utils.mask_sequence(videoframefeatures_fc, videoframes_mask)
    videoframefeatures_fcmeanpooling = videoframefeatures_fc.sum(dim=1)
    videoframefeatures_fcmeanpooling = videoframefeatures_fcmeanpooling.div(
        videoframes_lengths.unsqueeze(1).float())

    inputword_vectors = self.pretrained_words_layer(inputwords)
    #inputword_vectors: batch_size*num_words*wembed_dim

    #outputword_values = self.sentence_decoder_layer(inputword_vectors, videoframefeatures_fcmeanpooling, captionwords_mask)
    outputword_values = self.sentence_decoder_layer(inputword_vectors, videoframefeatures_fcmeanpooling)
    #outputword_values : batch_size*num_words*vocab_size
    outputword_log_probabilities = functional.log_softmax(outputword_values, dim=2)
    #outputword_log_probabilities : batch_size*num_words*vocab_size
    #outputword_log_probabilities = utils.mask_sequence(outputword_log_probabilities, captionwords_mask)

    return outputword_log_probabilities
    #batch_size*num_words*vocab_size
        return sequence1_sequence2_attention, sequence2_attention_weights


if __name__ == '__main__':
    #bidir = BiDirAttention({'similarity_function': 'DotProduct', 'one_shot_attention':True})
    bidir = BiDirAttention({
        'similarity_function': 'WeightedSumProjection',
        'sequence1_dim': 10,
        'sequence2_dim': 10,
        'projection_dim': 10,
        'one_shot_attention': True,
        'self_match_attention': False
    })
    sequence1_sequence2_attention, sequence2_attention_weights, sequence2_sequence1_attention, sequence1_attention_weights =\
        bidir(Variable(torch.randn(2, 5, 10)), Variable(torch.randn(2, 3, 10)),
              Variable(utils.sequence_mask(torch.LongTensor([5, 3]))),
              Variable(utils.sequence_mask(torch.LongTensor([3, 2]))))
    print(sequence1_sequence2_attention)
    print(sequence2_attention_weights)
    print(sequence2_sequence1_attention)
    print(sequence1_attention_weights)

    '''bidir = BiDirAttention({'similarity_function': 'WeightedSumProjection', 'sequence1_dim':7,
                            'sequence2_dim':7, 'projection_dim':7, 'one_shot_attention':False,
                            'self_match_attention':True})
    sequence = Variable(torch.randn(2,5,7))
    sequence_sequence_attention, sequence_attention_weights = bidir(sequence, sequence,\
        Variable(utils.sequence_mask(torch.LongTensor([5,3]))), Variable(utils.sequence_mask(torch.LongTensor([5,3]))))
    print(sequence_sequence_attention)
    print(sequence_attention_weights)'''

    '''bidir = BiDirAttention({'similarity_function': 'WeightedInputsConcatenation', 'input_dim': 10})
    print(bidir(Variable(torch.randn(2,5,10)), Variable(torch.randn(2,3,10))))
    bidir = BiDirAttention({'similarity_function': 'WeightedInputsDotConcatenation', 'input_dim': 10})
    print(bidir(Variable(torch.randn(2,5,10)), Variable(torch.randn(2,3,10))))'''
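# 'WeightedSumProjection' is the similarity function these attention modules are configured
# with. Its implementation is not shown in this excerpt; the sketch below assumes a standard
# additive (Bahdanau-style) form, score(u, v) = w.T tanh(W1 u + W2 v), projected through
# projection_dim. The class name and layout here are an illustrative stand-in, not the
# repository's code.
import torch
import torch.nn as nn

class WeightedSumProjectionSketch(nn.Module):
    def __init__(self, sequence1_dim, sequence2_dim, projection_dim):
        super().__init__()
        self.proj1 = nn.Linear(sequence1_dim, projection_dim, bias=False)
        self.proj2 = nn.Linear(sequence2_dim, projection_dim, bias=False)
        self.score = nn.Linear(projection_dim, 1, bias=False)

    def forward(self, sequence1, sequence2):
        #sequence1: batch*n1*d1, sequence2: batch*n2*d2 -> scores: batch*n1*n2
        p1 = self.proj1(sequence1).unsqueeze(2)      #batch*n1*1*projection_dim
        p2 = self.proj2(sequence2).unsqueeze(1)      #batch*1*n2*projection_dim
        return self.score(torch.tanh(p1 + p2)).squeeze(-1)

if __name__ == '__main__':
    sim = WeightedSumProjectionSketch(10, 10, 10)
    print(sim(torch.randn(2, 5, 10), torch.randn(2, 3, 10)).size())  #torch.Size([2, 5, 3])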
def train():
    cur_dir = os.getcwd()
    input_dir = 'input'
    glove_dir = 'glove/'
    glove_filename = 'glove.6B.300d.txt'
    glove_embdim = 300
    glove_filepath = os.path.join(glove_dir, glove_filename)

    data_parallel = False
    frame_trunc_length = 45

    train_batch_size = 32
    train_num_workers = 0
    train_pretrained = True
    train_pklexist = True
    eval_batch_size = 1

    print("Get train data...")
    spatial = True
    spatialpool = False
    #train_pkl_file = 'MSRVTT/Pixel/Resnet1000/trainvideo.pkl'
    if not spatial:
        train_pkl_file = 'MSRVTT/Pixel/Alexnet1000/trainvideo.pkl'
    else:
        train_pkl_file = 'MSRVTT/Pixel/Resnet51222/trainvideo.pkl'
    if spatialpool:
        train_pkl_file = 'MSRVTT/Pixel/Alexnet25622/trainvideo.pkl'
    file_names = [('MSRVTT/captions.json', 'MSRVTT/trainvideo.json', 'MSRVTT/Frames')]
    files = [[os.path.join(cur_dir, input_dir, filetype) for filetype in file]
             for file in file_names]
    train_pkl_path = os.path.join(cur_dir, input_dir, train_pkl_file)
    train_dataloader, vocab, glove, train_data_size = loader.get_train_data(
        files,
        train_pkl_path,
        glove_filepath,
        glove_embdim,
        batch_size=train_batch_size,
        num_workers=train_num_workers,
        pretrained=train_pretrained,
        pklexist=train_pklexist,
        data_parallel=data_parallel,
        frame_trunc_length=frame_trunc_length,
        spatial=spatial or spatialpool)

    # print("Get validation data...")
    # file_names = [('MSRVTT/captions.json', 'MSRVTT/valvideo.json.sample', 'MSRVTT/Frames')]
    # files = [[os.path.join(cur_dir, input_dir, filetype) for filetype in file] for file in file_names]
    # val_dataloader = loader.get_val_data(files, vocab, glove, eval_batch_size)

    modelname = 'FinalDropoutResnet020203NoResLinear'
    #modelname = 'FinalDropout000000Res'
    if spatial:
        modeltype = 'stal'
    else:
        modeltype = 'csal'
    save_dir = 'models/{}/'.format(modelname + modeltype)
    save_dir_path = os.path.join(cur_dir, save_dir)
    if not os.path.exists(save_dir_path):
        os.makedirs(save_dir_path)

    glovefile = open(os.path.join(save_dir, 'glove.pkl'), 'wb')
    pickle.dump(glove, glovefile)
    glovefile.close()

    vocabfile = open(os.path.join(save_dir, 'vocab.pkl'), 'wb')
    pickle.dump(vocab, vocabfile)
    vocabfile.close()

    print(save_dir)

    pretrained_wordvecs = glove.index2vec
    #model_name = MP, MPAttn, LSTM, LSTMAttn for CSAL
    '''hidden_dimension = 256 #glove_embdim
    dict_args = {
        "intermediate_layers" : ['layer4', 'fc'],
        "pretrained_feature_size" : 1000,
        #
        "word_embeddings" : pretrained_wordvecs,
        "word_embdim" : glove_embdim,
        "vocabulary_size" : len(pretrained_wordvecs),
        "use_pretrained_emb" : True,
        "backprop_embeddings" : False,
        #
        "encoder_configuration" : 'LSTMAttn',
        "encoder_input_dim" : hidden_dimension,
        "encoder_rnn_type" : 'LSTM',
        "encoder_rnn_hdim" : hidden_dimension,
        "encoder_num_layers" : 1,
        "encoder_dropout_rate" : 0.2,
        "encoderattn_projection_dim" : hidden_dimension/2,
        "encoderattn_query_dim" : hidden_dimension,
        #
        "decoder_rnn_input_dim" : glove_embdim + hidden_dimension,
        #"decoder_rnn_input_dim" : glove_embdim,
        "decoder_dropout_rate" : 0.2,
        "decoder_rnn_hidden_dim" : hidden_dimension,
        "decoder_tie_weights" : False,
        "decoder_rnn_type" : 'LSTM',
        "every_step": True,
        #"every_step": False
    }'''

    #SpatialTemporal, LSTMTrackSpatial, LSTMTrackSpatialTemporal
    hidden_dim = 256  #256
    dict_args = {
        "word_embeddings": pretrained_wordvecs,
        "word_embdim": glove_embdim,
        "use_pretrained_emb": True,
        "backprop_embeddings": False,
        "vocabulary_size": len(pretrained_wordvecs),
        "encoder_configuration": 'LSTMTrackSpatialTemporal',
        "frame_channel_dim": 512,
        "frame_spatial_dim": 2,
        "encoder_rnn_type": 'LSTM',
        "frame_channelred_dim": hidden_dim,
        "encoder_rnn_hdim": hidden_dim,
        "encoder_dropout_rate": 0.2,
        "encoderattn_projection_dim": hidden_dim / 2,
        "encoderattn_query_dim": hidden_dim,
        "encoder_linear": True,
        "decoder_rnn_word_dim": glove_embdim,
        #"decoder_rnn_input_dim" : hidden_dim + 256,
        "decoder_rnn_input_dim": hidden_dim + hidden_dim,
        "decoder_rnn_hidden_dim": hidden_dim,
        "decoder_rnn_type": 'LSTM',
        "decoder_top_dropout_rate": 0.3,
        "decoder_bottom_dropout_rate": 0.2,
        "residual_connection": False,
        "every_step": True
    }

    if not spatial:
        csal = CSAL(dict_args)
    else:
        csal = STAL(dict_args)
    print(dict_args)

    num_epochs = 500
    learning_rate = 1
    criterion = nn.NLLLoss(reduce=False)
    optimizer = optim.Adadelta(
        filter(lambda p: p.requires_grad, csal.parameters()),
        lr=learning_rate, rho=0.9, eps=1e-06, weight_decay=0)
    if USE_CUDA:
        if data_parallel:
            csal = nn.DataParallel(csal).cuda()
        else:
            csal = csal.cuda()
        criterion = criterion.cuda()

    print("Start training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        for i, batch in enumerate(train_dataloader):
            load_time = time.time()
            #######Load Data
            padded_imageframes_batch = Variable(torch.stack(batch[0]))  #batch_size*num_frames*3*224*224
            frame_sequence_lengths = Variable(torch.LongTensor(batch[1]))  #batch_size
            padded_inputwords_batch = Variable(torch.LongTensor(batch[2]))  #batch_size*num_words
            input_sequence_lengths = Variable(torch.LongTensor(batch[3]))  #batch_size
            padded_outputwords_batch = Variable(torch.LongTensor(batch[4]))  #batch_size*num_words
            output_sequence_lengths = Variable(torch.LongTensor(batch[5]))  #batch_size
            video_ids_list = batch[6]
            captionwords_mask = Variable(utils.sequence_mask(output_sequence_lengths))  #batch_size*num_words
            if USE_CUDA:
                async = data_parallel
                padded_imageframes_batch = padded_imageframes_batch.cuda(async=async)
                frame_sequence_lengths = frame_sequence_lengths.cuda(async=async)
        2)  #sequence: batch_size*num_words*iembed2
    sequence_attention_weights = utils.masked_softmax(similarity_vector, sequence_mask.float())
    vector_sequence_attention = utils.attention_pooling(sequence_attention_weights, sequence)
    #vector_sequence_attention: batch_size*iembed2

    #Will it save some memory?
    if self.training:
        sequence_attention_weights = None

    if softmax:
        return vector_sequence_attention, sequence_attention_weights
    else:
        return vector_sequence_attention, similarity_vector


if __name__ == '__main__':
    unidir = UniDirAttention({
        'similarity_function': 'WeightedSumProjection',
        'sequence1_dim': 6,
        'sequence2_dim': 20,
        'projection_dim': 10
    })
    vector_sequence_attention, sequence_attention_weights = unidir(
        Variable(torch.randn(2, 5, 6)),
        Variable(torch.randn(1, 20).expand(2, 20)),
        Variable(utils.sequence_mask(torch.LongTensor([5, 3]))),
        softmax=True)
    print(vector_sequence_attention)
    print(sequence_attention_weights)
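# The module above relies on utils.masked_softmax and utils.attention_pooling. A rough
# sketch of what they are assumed to do follows (the function names and signatures here are
# illustrative, not the repo's): a softmax restricted to unmasked positions, then a weighted
# sum over the sequence dimension.
import torch

def masked_softmax_sketch(scores, mask):
    #scores, mask: batch_size*num_words (mask is 1.0 for valid words, 0.0 for padding)
    scores = scores.masked_fill(mask == 0, float('-inf'))
    return torch.softmax(scores, dim=1)

def attention_pooling_sketch(weights, sequence):
    #weights: batch_size*num_words, sequence: batch_size*num_words*iembed2
    return torch.bmm(weights.unsqueeze(1), sequence).squeeze(1)  #batch_size*iembed2

if __name__ == '__main__':
    scores = torch.randn(2, 5)
    mask = torch.tensor([[1., 1., 1., 1., 1.], [1., 1., 1., 0., 0.]])
    w = masked_softmax_sketch(scores, mask)
    print(w.sum(dim=1))                                                #each row sums to 1
    print(attention_pooling_sketch(w, torch.randn(2, 5, 6)).size())    #torch.Size([2, 6])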
        if not self.training:
            sequence_selfattn_weights[ith_item] = sequence_attention_weights

    sequence_sequence_selfattn = sequence_sequence_selfattn.permute(1, 0, 2)
    if not self.training:
        sequence_selfattn_weights = sequence_selfattn_weights.permute(1, 0, 2)

    #sequence_mask: batch_size*num_words
    sequence_sequence_selfattn = utils.mask_sequence(sequence_sequence_selfattn, sequence_mask)
    if not self.training:
        sequence_selfattn_weights = utils.mask_sequence(sequence_selfattn_weights, sequence_mask)
    #sequence_sequence_selfattn: batch_size*num_words*iembed
    #sequence_selfattn_weights: batch_size*num_words*num_words
    return sequence_sequence_selfattn, sequence_selfattn_weights


if __name__ == '__main__':
    selfattn = SelfAttention({
        'similarity_function': 'WeightedSumProjection',
        'sequence_dim': 10,
        'projection_dim': 10
    })
    sequence_sequence_selfattn, sequence_selfattn_weights =\
        selfattn(Variable(torch.randn(3, 6, 10)),
                 Variable(utils.sequence_mask(torch.LongTensor([5, 3, 6]))))
    print(sequence_sequence_selfattn)
    print(sequence_selfattn_weights)
            h_t)  #h_t: batch_size*hidden_dim
        elif self.rnn_type == 'RNN':
            pass
        pointer_sequence_attention_weights[step] = sequence_attention_weights

    pointer_sequence_attention_weights = pointer_sequence_attention_weights.permute(1, 0, 2)
    return pointer_sequence_attention_weights
    #batch_size*num_steps*num_words


if __name__ == '__main__':
    '''ptrnet = PointerNetwork({'similarity_function': 'WeightedSumProjection', 'sequence_dim':15,
                             'projection_dim':10, 'rnn_type':'GRU', 'rnn_hdim':15})
    pointer_sequence_attention_weights =\
        ptrnet(Variable(torch.randn(3,6,15)), Variable(torch.randn(3,15)),\
               Variable(utils.sequence_mask(torch.LongTensor([5,3,6]))), 2)
    print(pointer_sequence_attention_weights)'''

    ptrnet = PointerNetwork({
        'similarity_function': 'WeightedSumProjection',
        'sequence_dim': 15,
        'projection_dim': 10,
        'rnn_type': 'GRU',
        'rnn_hdim': 15
    })
    pointer_sequence_attention_weights =\
        ptrnet(Variable(torch.randn(3, 6, 15)), Variable(torch.randn(1, 15).expand(3, 15)),
               Variable(utils.sequence_mask(torch.LongTensor([5, 3, 6]))), 2)
    print(pointer_sequence_attention_weights)
def forward(self, passage, passage_lengths, passagechars, passagechar_lengths,\
            question, question_lengths, questionchars, questionchar_lengths):
    #passage: batch_size*num_words_passage*wembdim
    #passage_lengths: batch_size
    #passagechars: batch_size*num_words_passage*num_max_chars
    #passagechar_lengths: batch_size*num_words_passage
    #question: batch_size*num_words_question*wembdim
    #question_lengths: batch_size
    #questionchars: batch_size*num_words_question*num_max_chars
    #questionchar_lengths: batch_size*num_words_question

    passage_embedding = passage[:, 0:passage_lengths.max()]
    question_embedding = question[:, 0:question_lengths.max()]

    passage_mask = Variable(utils.sequence_mask(passage_lengths))
    #passage_mask: batch_size*num_words_passage
    question_mask = Variable(utils.sequence_mask(question_lengths))
    #question_mask: batch_size*num_words_question

    ##### Character Embedding Layer
    if self.use_charemb:
        passage_char_embedding = self.charemb_layer(passagechars, passagechar_lengths)
        #passage_char_embedding: batch_size*num_words_passage*2.charemb_rnn_hdim
        question_char_embedding = self.charemb_layer(questionchars, questionchar_lengths)
        #question_char_embedding: batch_size*num_words_question*2.charemb_rnn_hdim

        ##### Char and Word Embedding Concatenation
        passage_embedding = torch.cat((passage, passage_char_embedding), dim=2)
        question_embedding = torch.cat((question, question_char_embedding), dim=2)
        passage_embedding = passage_embedding[:, 0:passage_lengths.max()]
        question_embedding = question_embedding[:, 0:question_lengths.max()]

    passage_embedding = utils.mask_sequence(passage_embedding, passage_mask)
    question_embedding = utils.mask_sequence(question_embedding, question_mask)
    #passage_embedding: batch_size*num_words_passage*(2.charemb_rnn_hdim+wembdim)
    #question_embedding: batch_size*num_words_question*(2.charemb_rnn_hdim+wembdim)

    ##### Context Embedding Layer
    #passage_embedding, passage_lengths = self.contextemb_layer(passage_embedding, passage_lengths)
    #question_embedding, question_lengths = self.contextemb_layer(question_embedding, question_lengths)
    passage_embedding, _ = self.contextemb_layer(passage_embedding, passage_lengths)
    question_embedding, _ = self.contextemb_layer(question_embedding, question_lengths)
    #passage_embedding: batch_size*num_words_passage*2.contextemb_rnn_hdim
    #question_embedding: batch_size*num_words_question*2.contextemb_rnn_hdim
    #Skipping recomputation of passage_mask and question_mask

    ##### BiDAF Layer
    passage_question_attention, question_attention_weights, question_passage_attention, passage_attention_weights =\
        self.bidaf_layer(passage_embedding, question_embedding, passage_mask, question_mask)
    #passage_question_attention: batch_size*num_words_passage*2.contextemb_rnn_hdim
    #question_attention_weights: batch_size*num_words_passage*num_words_question
    #question_passage_attention: batch_size*2.contextemb_rnn_hdim
    #passage_attention_weights: batch_size*num_words_passage

    question_passage_attention = question_passage_attention.unsqueeze(1)
    #question_passage_attention: batch_size*1*2.contextemb_rnn_hdim
    question_passage_attention = question_passage_attention.expand(
        question_passage_attention.size(0),
        passage_question_attention.size(1),
        question_passage_attention.size(2))
    #question_passage_attention: batch_size*num_words_passage*2.contextemb_rnn_hdim
    #passage_question_attention: batch_size*num_words_passage*2.contextemb_rnn_hdim
    #Skipping masking for passage_question_attention, question_passage_attention

    question_aware_passage_representation = torch.cat((passage_embedding,\
        passage_question_attention,\
        passage_embedding*passage_question_attention,\
        passage_embedding*question_passage_attention), dim=-1)
    #question_aware_passage_representation: batch_size*num_words_passage*8.contextemb_rnn_hdim

    if self.use_dropout:
        question_aware_passage_representation = self.dropout_layer(
            question_aware_passage_representation)

    question_aware_passage_context_representation_one, _ = self.modeling_layer_one(
        question_aware_passage_representation, passage_lengths)
    #question_aware_passage_context_representation_one: batch_size*num_words_passage*2.modelinglayer_rnn_hdim
    if self.use_dropout:
        question_aware_passage_context_representation_one = self.dropout_layer(
            question_aware_passage_context_representation_one)
    question_aware_passage_context_representation_two, _ = self.modeling_layer_two(
        question_aware_passage_context_representation_one, passage_lengths)
    #question_aware_passage_context_representation_two: batch_size*num_words_passage*2.modelinglayer_rnn_hdim
    if self.use_dropout:
        question_aware_passage_context_representation_two = self.dropout_layer(
            question_aware_passage_context_representation_two)

    start_index_representation = torch.cat((question_aware_passage_representation,\
        question_aware_passage_context_representation_one), dim=-1)
    end_index_representation = torch.cat((question_aware_passage_representation,\
        question_aware_passage_context_representation_two), dim=-1)
    #start_index_representation: batch_size*num_words_passage*(8.contextemb_rnn_hdim + 2.modelinglayer_rnn_hdim)
    #end_index_representation: batch_size*num_words_passage*(8.contextemb_rnn_hdim + 2.modelinglayer_rnn_hdim)

    start_index_values = self.start_index_linear(start_index_representation).squeeze()
    end_index_values = self.end_index_linear(end_index_representation).squeeze()

    '''#start_index_probabilities: batch_size*num_words_passage
    #end_index_probabilities: batch_size*num_words_passage
    start_index_probabilities = utils.masked_softmax(start_index_values, passage_mask.float())
    end_index_probabilities = utils.masked_softmax(end_index_values, passage_mask.float())'''

    start_index_log_probabilities = start_index_values + passage_mask.float().log()
    end_index_log_probabilities = end_index_values + passage_mask.float().log()
    start_index_log_probabilities = functional.log_softmax(start_index_log_probabilities, dim=-1)
    end_index_log_probabilities = functional.log_softmax(end_index_log_probabilities, dim=-1)
    #start_index_log_probabilities: batch_size*num_words_passage
    #end_index_log_probabilities: batch_size*num_words_passage

    #uncomment the below line in compare mode
    #return start_index_log_probabilities, end_index_log_probabilities, question_attention_weights
    return start_index_log_probabilities, end_index_log_probabilities
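# The start/end distributions above are masked by adding passage_mask.float().log() to the
# scores before log_softmax: log(1) = 0 leaves valid positions unchanged, while log(0) = -inf
# drives padded positions to zero probability. A small self-contained check of that identity
# (the tensor values are made up for illustration):
import torch
import torch.nn.functional as functional

scores = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])          #last position is padding

log_probs = functional.log_softmax(scores + mask.log(), dim=-1)
print(log_probs.exp())        #padded position gets probability 0, the rest renormalize

#equivalent formulation with an explicit fill
reference = functional.log_softmax(scores.masked_fill(mask == 0, float('-inf')), dim=-1)
print(torch.allclose(log_probs.exp(), reference.exp()))    #True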
            h_t, c_t = self.rnn(input, (h_t, c_t))
            #h_t: batch_size*hidden_dim
            if reverse:
                h_t = utils.mask_sequence(h_t, hidden_mask)
                c_t = utils.mask_sequence(c_t, hidden_mask)
        elif self.rnn_type == 'GRU':
            h_t = self.rnn(input, h_t)
            #h_t: batch_size*hidden_dim
            if reverse:
                h_t = utils.mask_sequence(h_t, hidden_mask)
        elif self.rnn_type == 'RNN':
            pass

        sequence1_sequence2_matchattn[ith_item] = h_t
        if not self.training:
            sequence2_matchattn_weights[ith_item] = sequence2_attention_weights

    sequence1_sequence2_matchattn = sequence1_sequence2_matchattn.permute(1, 0, 2)
    if not self.training:
        sequence2_matchattn_weights = sequence2_matchattn_weights.permute(1, 0, 2)

    sequence1_mask = sequence1_mask.permute(1, 0)
    #sequence1_mask: batch_size*num_words1
    sequence1_sequence2_matchattn = utils.mask_sequence(sequence1_sequence2_matchattn, sequence1_mask)
    if not self.training:
        sequence2_matchattn_weights = utils.mask_sequence(sequence2_matchattn_weights, sequence1_mask)
    #sequence1_sequence2_matchattn: batch_size*num_words1*hidden_dim
    #sequence2_matchattn_weights: batch_size*num_words1*num_words2
    return sequence1_sequence2_matchattn, sequence2_matchattn_weights


if __name__ == '__main__':
    matchattn = MatchAttention({
        'similarity_function': 'WeightedSumProjection',
        'sequence1_dim': 12,
        'sequence2_dim': 8,
        'projection_dim': 10,
        'rnn_type': 'LSTM',
        'rnn_hdim': 15,
        'gated_attention': True
    })
    sequence1_sequence2_matchattn, sequence2_matchattn_weights =\
        matchattn(Variable(torch.randn(3, 6, 12)), Variable(torch.randn(3, 4, 8)),
                  Variable(utils.sequence_mask(torch.LongTensor([5, 3, 6]))),
                  Variable(utils.sequence_mask(torch.LongTensor([4, 3, 2]))), reverse=True)
    print(sequence1_sequence2_matchattn)
    print(sequence2_matchattn_weights)
def train():
    cur_dir = os.getcwd()
    input_dir = 'input'
    glove_dir = '../SQUAD/glove/'
    glove_filename = 'glove.6B.50d.txt'
    glove_embdim = 50
    glove_filepath = os.path.join(glove_dir, glove_filename)

    train_batch_size = 2
    eval_batch_size = 1

    file_names = [('MSRVTT/captions.json', 'MSRVTT/trainvideo.json', 'MSRVTT/Frames')]
    files = [[os.path.join(cur_dir, input_dir, filetype) for filetype in file]
             for file in file_names]
    train_dataloader, vocab, glove, train_data_size = loader.get_train_data(
        files, glove_filepath, glove_embdim, train_batch_size)

    file_names = [('MSRVTT/captions.json', 'MSRVTT/valvideo.json', 'MSRVTT/Frames')]
    files = [[os.path.join(cur_dir, input_dir, filetype) for filetype in file]
             for file in file_names]
    val_dataloader = loader.get_val_data(files, vocab, glove, eval_batch_size)

    save_dir = 'models/baseline/'
    save_dir_path = os.path.join(cur_dir, save_dir)
    if not os.path.exists(save_dir_path):
        os.makedirs(save_dir_path)

    glovefile = open(os.path.join(save_dir, 'glove.pkl'), 'wb')
    pickle.dump(glove, glovefile)
    glovefile.close()

    vocabfile = open(os.path.join(save_dir, 'vocab.pkl'), 'wb')
    pickle.dump(vocab, vocabfile)
    vocabfile.close()

    pretrained_wordvecs = glove.index2vec
    dict_args = {
        "intermediate_layers": ['layer4', 'fc'],
        "word_embeddings": pretrained_wordvecs,
        "pretrained_embdim": glove_embdim,
        "vocabulary_size": len(pretrained_wordvecs),
        "decoder_rnn_hidden_dim": glove_embdim,
        "decoder_tie_weights": True,
        "decoder_rnn_type": 'LSTM',
        "pretrained_feature_size": 1000
    }
    csal = CSAL(dict_args)

    num_epochs = 4
    learning_rate = 1.0
    criterion = nn.NLLLoss(reduce=False)
    optimizer = optim.Adadelta(
        filter(lambda p: p.requires_grad, csal.parameters()),
        lr=learning_rate, rho=0.9, eps=1e-06, weight_decay=0)
    if USE_CUDA:
        csal = csal.cuda()
        criterion = criterion.cuda()

    for epoch in range(num_epochs):
        for i, batch in enumerate(train_dataloader):
            #######Load Data
            padded_imageframes_batch = Variable(torch.stack(batch[0]))  #batch_size*num_frames*3*224*224
            frame_sequence_lengths = Variable(torch.LongTensor(batch[1]))  #batch_size
            padded_inputwords_batch = Variable(torch.LongTensor(batch[2]))  #batch_size*num_words
            input_sequence_lengths = Variable(torch.LongTensor(batch[3]))  #batch_size
            padded_outputwords_batch = Variable(torch.LongTensor(batch[4]))  #batch_size*num_words
            output_sequence_lengths = Variable(torch.LongTensor(batch[5]))  #batch_size
            video_ids_list = batch[6]
            captionwords_mask = Variable(utils.sequence_mask(output_sequence_lengths))  #batch_size*num_words
            if USE_CUDA:
                padded_imageframes_batch = padded_imageframes_batch.cuda()
                frame_sequence_lengths = frame_sequence_lengths.cuda()
                padded_inputwords_batch = padded_inputwords_batch.cuda()
                input_sequence_lengths = input_sequence_lengths.cuda()
                padded_outputwords_batch = padded_outputwords_batch.cuda()
                output_sequence_lengths = output_sequence_lengths.cuda()
                captionwords_mask = captionwords_mask.cuda()

            #######Forward
            csal = csal.train()
            optimizer.zero_grad()
            outputword_log_probabilities = csal(padded_imageframes_batch, frame_sequence_lengths,\
                                                padded_inputwords_batch, input_sequence_lengths)

            #######Calculate Loss
            outputword_log_probabilities = outputword_log_probabilities.permute(0, 2, 1)
            #outputword_log_probabilities: batch_size*vocab_size*num_words
            #padded_outputwords_batch: batch_size*num_words
            losses = criterion(outputword_log_probabilities, padded_outputwords_batch)
            #losses: batch_size*num_words
            losses = losses * captionwords_mask.float()
            loss = losses.sum()

            #######Backward
            loss.backward()
            optimizer.step()

            if ((i + 1) % 2 == 0):
                print('Epoch: [{0}/{1}], Step: [{2}/{3}], Loss: {4}'.format(\
                    epoch+1, num_epochs, i+1, train_data_size//train_batch_size, loss.data[0]))
                break

        if (epoch % 1 == 0):  #After how many epochs
            #Get Validation Loss to stop overfitting
            val_loss, bleu = evaluator.evaluate(val_dataloader, csal)
            #print("val_loss")
            #Early Stopping not required
            filename = 'csal' + '.pth'
            file = open(os.path.join(save_dir, filename), 'wb')
            torch.save({
                'state_dict': csal.state_dict(),
                'dict_args': dict_args
            }, file)
            print('Saving the model to {}'.format(save_dir))
            file.close()
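# A compact illustration of the masked caption loss used in the loop above: per-token NLL
# with reduction disabled, zeroed at padded positions, then summed. reduction='none' is the
# newer spelling of the reduce=False argument used above; the tensors below are random
# stand-ins for the model outputs, and the per-token normalization at the end is an optional
# variant, not what the loop does.
import torch
import torch.nn as nn
import torch.nn.functional as functional

batch_size, num_words, vocab_size = 2, 4, 7
criterion = nn.NLLLoss(reduction='none')

log_probs = functional.log_softmax(torch.randn(batch_size, num_words, vocab_size), dim=2)
targets = torch.randint(0, vocab_size, (batch_size, num_words))
mask = torch.tensor([[1., 1., 1., 1.], [1., 1., 0., 0.]])   #second caption has 2 real words

losses = criterion(log_probs.permute(0, 2, 1), targets)     #batch_size*num_words
masked_loss = (losses * mask).sum()                         #what the training loop backpropagates
per_token_loss = (losses * mask).sum() / mask.sum()         #optional per-token normalization
print(masked_loss.item(), per_token_loss.item())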
def forward(self, videoframes, videoframes_lengths, inputwords, captionwords_lengths):
    #videoframes : batch_size*num_frames*3*224*224
    #videoframes_lengths : batch_size
    #inputwords : batch_size*num_words
    #outputwords : batch_size*num_words
    #captionwords_lengths : batch_size

    videoframes = videoframes[:, 0:videoframes_lengths.data.max()].contiguous()
    videoframes_mask = Variable(utils.sequence_mask(videoframes_lengths))
    #videoframes_mask: batch_size*num_frames

    #redundant because we are masking the loss
    #captionwords_mask = Variable(utils.sequence_mask(captionwords_lengths))
    #captionwords_mask: batch_size*num_words

    #batch_size, num_frames, rgb, height, width = videoframes.size()
    #videoframes = videoframes.view(-1,rgb,height,width).contiguous()
    #videoframes : batch_size.num_frames*3*224*224
    #videoframefeatures = self.pretrained_vision_layer(videoframes)
    #videoframefeatures_fc = videoframefeatures[1]
    #videoframefeatures_fc : batch_size.num_frames*1000

    videoframes = videoframes.contiguous()
    batch_size, num_frames, num_features, H, W = videoframes.size()
    if H == 1:
        videoframefeatures_fc = videoframes.view(-1, num_features).contiguous()
        videoframefeatures_fc = self.vision_feature_dimred_layer(videoframefeatures_fc)
        #videoframefeatures_fc : batch_size.num_frames*rnn_hdim
        _, feature_dim = videoframefeatures_fc.size()
        videoframefeatures_fc = videoframefeatures_fc.view(batch_size, num_frames, feature_dim)
        #videoframefeatures_fc : batch_size*num_frames*rnn_hdim
        videoframefeatures_fc = utils.mask_sequence(videoframefeatures_fc, videoframes_mask)
    else:
        videoframefeatures_fc = videoframes

    inputword_vectors = self.pretrained_words_layer(inputwords)
    #inputword_vectors: batch_size*num_words*wembed_dim

    if not self.training:
        return self.sentence_decoder_layer.inference(
            self.frame_encoder_layer, videoframefeatures_fc,
            videoframes_lengths, self.pretrained_words_layer)

    #outputword_values = self.sentence_decoder_layer(inputword_vectors, videoframefeatures_fcmeanpooling, captionwords_mask)
    outputword_log_probabilities = self.sentence_decoder_layer(
        inputword_vectors, self.frame_encoder_layer,
        videoframefeatures_fc, videoframes_lengths)
    #outputword_log_probabilities : batch_size*num_words*vocab_size
    #outputword_log_probabilities = functional.log_softmax(outputword_values, dim=2)
    #outputword_log_probabilities = utils.mask_sequence(outputword_log_probabilities, captionwords_mask)

    return outputword_log_probabilities
    #batch_size*num_words*vocab_size
def forward(self, isequence, ilenghts, spatialqueryvector=None, temporalqueryvector=None):
    #isequence : batch_size*num_frames*C*H*W
    #ilenghts : batch_size
    #spatialqueryvector : attention over the frames of the video
    #temporalqueryvector : attention over hidden states of the tracking LSTM

    contextvector, temporalvectors, spatialvectors = None, None, None
    isequence = self.dropout_layer(isequence)

    batch_size, num_frames, channel_dim, height, width = isequence.size()
    isequence = isequence.view(batch_size, num_frames, channel_dim, -1)
    #isequence : batch_size*num_frames*channel_dim*num_blocks
    isequence = isequence.permute(1, 0, 3, 2)
    #isequence : num_frames*batch_size*num_blocks*channel_dim
    num_blocks = isequence.size(2)

    if self.use_linear:
        isequence = functional.relu(self.linear(isequence))  #Dropout?

    if self.temporalattention:
        temporalvectors = Variable(isequence.data.new(num_frames, batch_size, self.context_dim).zero_())

    if self.recurrent:
        h_t = spatialqueryvector
        if self.rnn_type == 'LSTM':
            c_t = self.init_hidden(batch_size)
        for step in range(num_frames):
            spatialvectors = isequence[step]
            #spatialvectors: batch_size*num_blocks*channel_dim
            slenghts = ilenghts.data.new(batch_size).zero_() + num_blocks
            smask = Variable(utils.sequence_mask(Variable(slenghts)))
            spatattnvector, spatattnweights = self.spatial_attention_function(spatialvectors, h_t, smask)
            #print(spatattnweights)
            if self.rnn_type == 'LSTM':
                h_t, c_t = self.trackrnn(spatattnvector, (h_t, c_t))
                #h_t: batch_size*hidden_dim
            elif self.rnn_type == 'GRU':
                h_t = self.trackrnn(spatattnvector, h_t)
                #h_t: batch_size*hidden_dim
            elif self.rnn_type == 'RNN':
                pass
            if self.temporalattention:
                temporalvectors[step] = h_t
        contextvector = h_t
    else:
        for step in range(num_frames):
            spatialvectors = isequence[step]
            #spatialvectors: batch_size*num_blocks*channel_dim
            slenghts = ilenghts.data.new(batch_size).zero_() + num_blocks
            smask = Variable(utils.sequence_mask(Variable(slenghts)))
            spatattnvector, spatattnweights = self.spatial_attention_function(spatialvectors, spatialqueryvector, smask)
            #print(spatattnweights)
            temporalvectors[step] = spatattnvector
        contextvector = None  #Should undergo temporal attention

    if self.temporalattention:
        imask = Variable(utils.sequence_mask(ilenghts))
        temporalvectors = temporalvectors.permute(1, 0, 2)
        #temporalvectors: batch_size*num_frames*context_dim
        tempattnvector, tempattnweights = self.temporal_attention_function(temporalvectors, temporalqueryvector, imask)
        contextvector = tempattnvector
        #print(tempattnweights)

    return contextvector
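# The encoder above flattens each frame's C*H*W feature map into H*W "block" vectors of size C
# and attends over those blocks at every step of the tracking RNN. A small shape-only sketch
# of that layout step, with made-up dimensions for illustration:
import torch

batch_size, num_frames, channel_dim, height, width = 2, 3, 512, 2, 2
frames = torch.randn(batch_size, num_frames, channel_dim, height, width)

blocks = frames.view(batch_size, num_frames, channel_dim, -1)   #batch*frames*C*(H.W)
blocks = blocks.permute(1, 0, 3, 2)                             #frames*batch*blocks*C
print(blocks.size())        #torch.Size([3, 2, 4, 512]) -> 4 blocks of 512-d features per frame

#per-frame slice fed to spatial attention at step t
step_t = blocks[0]          #batch_size*num_blocks*channel_dim
print(step_t.size())        #torch.Size([2, 4, 512])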
def forward(self, passage, passage_lengths, passagechars, passagechar_lengths,\
            question, question_lengths, questionchars, questionchar_lengths):
    #passage: batch_size*num_words_passage*wembdim
    #passage_lengths: batch_size
    #passagechars: batch_size*num_words_passage*num_max_chars
    #passagechar_lengths: batch_size*num_words_passage
    #question: batch_size*num_words_question*wembdim
    #question_lengths: batch_size
    #questionchars: batch_size*num_words_question*num_max_chars
    #questionchar_lengths: batch_size*num_words_question

    passage_embedding = passage[:, 0:passage_lengths.max()]
    question_embedding = question[:, 0:question_lengths.max()]

    passage_mask = Variable(utils.sequence_mask(passage_lengths))
    #passage_mask: batch_size*num_words_passage
    question_mask = Variable(utils.sequence_mask(question_lengths))
    #question_mask: batch_size*num_words_question

    ##### Character Embedding Layer
    if self.use_charemb:
        passage_char_embedding = self.charemb_layer(passagechars, passagechar_lengths)
        #passage_char_embedding: batch_size*num_words_passage*2.charemb_rnn_hdim
        question_char_embedding = self.charemb_layer(questionchars, questionchar_lengths)
        #question_char_embedding: batch_size*num_words_question*2.charemb_rnn_hdim

        ##### Char and Word Embedding Concatenation
        passage_embedding = torch.cat((passage, passage_char_embedding), dim=2)
        question_embedding = torch.cat((question, question_char_embedding), dim=2)
        passage_embedding = passage_embedding[:, 0:passage_lengths.max()]
        question_embedding = question_embedding[:, 0:question_lengths.max()]

    passage_embedding = utils.mask_sequence(passage_embedding, passage_mask)
    question_embedding = utils.mask_sequence(question_embedding, question_mask)
    #passage_embedding: batch_size*num_words_passage*(2.charemb_rnn_hdim+wembdim)
    #question_embedding: batch_size*num_words_question*(2.charemb_rnn_hdim+wembdim)

    ##### Context Embedding Layer
    #passage_embedding, passage_lengths = self.contextemb_layer(passage_embedding, passage_lengths)
    #question_embedding, question_lengths = self.contextemb_layer(question_embedding, question_lengths)
    passage_embedding, _ = self.contextemb_layer(passage_embedding, passage_lengths)
    question_embedding, _ = self.contextemb_layer(question_embedding, question_lengths)
    #passage_embedding: batch_size*num_words_passage*2.contextemb_rnn_hdim
    #question_embedding: batch_size*num_words_question*2.contextemb_rnn_hdim
    #Skipping recomputation of passage_mask and question_mask

    ##### Gated Attention Layer
    passage_question_matchattn_forward, question_matchattn_weights_forward = \
        self.gated_attention_layer_forward(passage_embedding, question_embedding, passage_mask, question_mask)
    #passage_question_matchattn_forward: batch_size*num_words_passage*2.contextemb_rnn_hdim
    #question_matchattn_weights_forward: batch_size*num_words_passage*num_words_question
    passage_question_matchattn = passage_question_matchattn_forward

    if self.use_bidirectional:
        passage_question_matchattn_reverse, question_matchattn_weights_reverse = \
            self.gated_attention_layer_backward(utils.reverse_sequence(passage_embedding), question_embedding,\
                                                utils.reverse_sequence(passage_mask), question_mask, reverse=True)
        passage_question_matchattn_reverse = utils.reverse_sequence(passage_question_matchattn_reverse)
        if not self.training:
            question_matchattn_weights_reverse = utils.reverse_sequence(question_matchattn_weights_reverse)
        #passage_question_matchattn_reverse: batch_size*num_words_passage*2.contextemb_rnn_hdim
        passage_question_matchattn = torch.cat(
            (passage_question_matchattn_forward, passage_question_matchattn_reverse), dim=-1)
        #passage_question_matchattn: batch_size*num_words_passage*4.contextemb_rnn_hdim

    question_aware_passage_representation = passage_question_matchattn

    if self.use_selfmatching:
        passage_passage_selfattn, passage_selfattn_weights = \
            self.self_matching_layer(passage_question_matchattn, passage_mask)
        #passage_passage_selfattn: batch_size*num_words_passage*2.contextemb_rnn_hdim
        #passage_selfattn_weights: batch_size*num_words_passage*num_words_passage
        #Skipping masking for passage_passage_selfattn
        question_aware_passage_representation = torch.cat(
            (passage_question_matchattn, passage_passage_selfattn), dim=-1)

    #if self.use_dropout: question_aware_passage_representation = self.dropout_layer(question_aware_passage_representation)

    if self.gated_selfmatching:
        question_aware_passage_representation = self.selfmatchinggate(
            question_aware_passage_representation, question_aware_passage_representation)

    question_aware_passage_representation, _ = self.modeling_layer(
        question_aware_passage_representation, passage_lengths)
    #question_aware_passage_representation: batch_size*num_words_passage*2.modelinglayer_rnn_hdim
    if self.use_dropout:
        question_aware_passage_representation = self.dropout_layer(
            question_aware_passage_representation)

    self.question_query_vector_expand = self.question_query_vector.unsqueeze(0).expand(
        question_embedding.size(0), self.question_query_vector.size(0))
    #self.question_query_vector_expand: batch_size*2.contextemb_rnn_hdim (question_query_vector_dim = 2.contextemb_rnn_hdim)

    queryvector_question_attention, question_attention_weights = \
        self.question_attention_layer(question_embedding, self.question_query_vector_expand, question_mask)
    #queryvector_question_attention: batch_size*2.contextemb_rnn_hdim
    #question_attention_weights: batch_size*num_words_question

    pointer_passage_attention_values = \
        self.pointer_network_layer(question_aware_passage_representation,
                                   queryvector_question_attention, passage_mask, 2)
    #pointer_passage_attention_values: batch_size*2*num_words_passage

    start_index_values, end_index_values = torch.chunk(pointer_passage_attention_values, 2, 1)
    start_index_values = start_index_values.squeeze()
    end_index_values = end_index_values.squeeze()

    '''start_index_probabilities = utils.masked_softmax(start_index_values, passage_mask.float())
    end_index_probabilities = utils.masked_softmax(end_index_values, passage_mask.float())'''

    start_index_log_probabilities = start_index_values + passage_mask.float().log()
    end_index_log_probabilities = end_index_values + passage_mask.float().log()
    start_index_log_probabilities = functional.log_softmax(start_index_log_probabilities, dim=-1)
    end_index_log_probabilities = functional.log_softmax(end_index_log_probabilities, dim=-1)
    #start_index_log_probabilities: batch_size*num_words_passage
    #end_index_log_probabilities: batch_size*num_words_passage

    #uncomment the below line in compare mode
    #return start_index_log_probabilities, end_index_log_probabilities, (question_matchattn_weights_forward + question_matchattn_weights_reverse)/2
    return start_index_log_probabilities, end_index_log_probabilities