def train():
    saved_model_path = None
    use_multiple_caption = False
    captions_per_image = 5
    save_steps = 10000
    n_epoch = 1000
    learning_rate = 5e-5
    scheduler_step_size = 15
    batch_size = 32

    vocab = Vocabulary.load(vocab_path)
    vocab_size = (vocab.idx // 100 + 1) * 100  # round vocab size up to the next multiple of 100

    # if use_multiple_caption:
    #     dataset = data.COCODemoDataset2(split='train', vocab=vocab, mode='train', image_size=224,
    #                                     captions_per_image=captions_per_image)
    #     collate_fn = data.collate_fn_2
    # else:
    #     dataset = data.COCODemoDataset1(split='train', vocab=vocab, mode='train', image_size=224)
    #     collate_fn = data.collate_fn
    dataset = data_video.MSVDDatasetCHN(vocab=vocab,
                                        segment_method=segment_method,
                                        split='train',
                                        feature=feature_type)
    collate_fn = data_video.collate_fn
    data_loader = DataLoader(dataset=dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             collate_fn=collate_fn,
                             num_workers=0)

    lm_config = model_cnn.LanguageModelConfig()
    lm_config.image_feature_dim = 2048
    lm_config.vocab_size = vocab_size
    lm_config.use_attention = False
    lm_cnn = model_cnn.LanguageModelConv(lm_config)
    lm_cnn.to(device)
    lm_cnn.train(True)

    epoch = 0
    global_step = 0
    optimizer = torch.optim.RMSprop(lm_cnn.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=scheduler_step_size,
                                                gamma=.1)
    criterion = nn.CrossEntropyLoss()

    while epoch < n_epoch:
        scheduler.step()
        print('1 epoch = {} steps'.format(len(data_loader)))
        for _, (image_filename_list, image_feature, captions, lengths) in enumerate(data_loader):
            global_step += 1

            # mask covers only the real (unpadded) caption tokens
            mask = torch.zeros(captions.shape)
            for i in range(len(captions)):
                mask[i][:lengths[i]] = 1

            batch_size, max_caption_len = captions.shape
            # TODO: caption = ['<S>', 'hello', 'world'], no ending symbol ?
            captions = captions.to(device)  # (batch_size, caption_length)
            image_feature = image_feature.to(device)  # image feature

            # image_feature, image_feature_fc7 = image_cnn.forward(images)  # (batch_size, feature_dim)
            # if use_multiple_caption:
            #     image_feature, image_feature_fc7 = repeat_features(image_feature, image_feature_fc7, captions_per_image)

            # word output
            word_output, attn = lm_cnn.forward(
                None, image_feature, input_word=captions)  # (batch_size, vocab_size, max_len)
            word_output = word_output[:, :, :-1]  # (batch_size, vocab_size, max_len - 1)

            # shift targets by one: position t predicts token t + 1
            captions = captions[:, 1:].contiguous()
            mask = mask[:, 1:].contiguous()

            word_output_t = word_output.permute(0, 2, 1).contiguous().view(
                batch_size * (max_caption_len - 1), -1)
            captions_t = captions.view(batch_size * (max_caption_len - 1), 1)
            maskids = torch.nonzero(mask.view(-1)).numpy().reshape(-1)
            caption_loss = criterion(word_output_t[maskids, :],
                                     captions_t[maskids, :].view(maskids.shape[0]))

            optimizer.zero_grad()
            caption_loss.backward()
            optimizer.step()

            print(epoch, global_step, 'loss:', caption_loss.item())
            # if global_step % 10000 == 0:
            #     save_model('../models_cnn/model-{}-ep{}'.format(global_step, epoch),
            #                (lm_cnn, optimizer, epoch, global_step))

        if epoch % 100 == 0:
            test1(lm_cnn, global_step)
        epoch += 1

def train():
    save_steps = 10000
    n_epoch = 1000
    learning_rate = 1e-4
    scheduler_step_size = 15
    batch_size = 32

    vocab = Vocabulary.load(vocab_path)
    vocab_size = (vocab.idx // 100 + 1) * 100

    dataset = data_video.MSVDDatasetCHN(vocab=vocab,
                                        segment_method=segment_method,
                                        split='train',
                                        feature=feature_type)
    collate_fn = data_video.collate_fn
    data_loader = DataLoader(dataset=dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             collate_fn=collate_fn,
                             num_workers=0)

    lm_config = model_lstm.LanguageModelConfig()
    lm_config.vocab_size = vocab_size
    lm_config.image_feature_dim = 512
    lm_lstm = model_lstm.LanguageModelLSTM(lm_config)
    lm_lstm.to(device)
    lm_lstm.train(True)

    epoch = 0
    global_step = 0
    optimizer = torch.optim.Adam(lm_lstm.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=scheduler_step_size,
                                                gamma=.1)
    criterion = nn.CrossEntropyLoss()

    while epoch < n_epoch:
        scheduler.step()
        print('1 epoch = {} steps'.format(len(data_loader)))
        for _, (image_filename_list, features, captions, lengths) in enumerate(data_loader):
            global_step += 1
            features = features.to(device)
            captions = captions.to(device)

            word_prob_output, last_hidden_state = lm_lstm.forward(features, captions, lengths)
            # print(word_prob_output.shape)  # (batch, seq_len, vocab_size)

            target = torch.nn.utils.rnn.pack_padded_sequence(
                captions, lengths=lengths, batch_first=True)[0]
            loss = criterion(word_prob_output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print('epoch {}, global step: {}, loss: {}'.format(epoch, global_step, loss))
            if global_step % 10 == 0:
                print(data_video.get_res5c_feature_of_video.cache_info())
            if global_step % save_steps == 0 and global_step > 0:
                test1(lm_lstm, global_step)
                # save_model('../models/model-lstm-{}'.format(global_step),
                #            (lm_lstm, optimizer, epoch, global_step))
        epoch += 1

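
# The commented-out save_model(...) calls above refer to a checkpoint helper
# defined elsewhere in this repo. The function below is only a hypothetical
# sketch of such a helper with the same call signature
# (path, (model, optimizer, epoch, global_step)); the repo's actual save_model
# may differ.
def save_checkpoint(path, state):
    model, optimizer, epoch, global_step = state
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'global_step': global_step,
    }, path)
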
def test1(lm_cnn, global_step):
    annotation_file_name = '../all_results/results_cnn_chn_char_resnet/msvd_annotation_{}.json'.format(global_step)
    output_file_name = '../all_results/results_cnn_chn_char_resnet/msvd_result_{}.json'.format(global_step)
    if not os.path.exists(os.path.dirname(annotation_file_name)):
        os.makedirs(os.path.dirname(annotation_file_name))

    beam_size = 3
    data_loader_batch_size = 4
    max_sentence_length = 15

    vocab = Vocabulary.load(vocab_path)
    dataset = data_video.MSVDDatasetCHN(vocab=vocab,
                                        segment_method=segment_method,
                                        caption_mode='text',
                                        split='test',
                                        feature=feature_type)
    # dataset = data.DemoDataset()
    data_loader = DataLoader(dataset=dataset,
                             batch_size=data_loader_batch_size,
                             shuffle=False,
                             num_workers=0)

    start_word_id = vocab.get_index(start_word)
    end_word_id = vocab.get_index(end_word)

    result_set = set()
    result_obj = []
    annotation_obj = {
        'info': 'N/A',
        'licenses': 'N/A',
        'type': 'captions',
        'images': [],
        'annotations': []
    }
    caption_id = 0

    for index, (image_filename_list, image_feature, captions) in enumerate(data_loader):
        # print('images', len(images))
        # print('captions', len(captions))  # captions[0] = list of 5 captions

        # actual batch size may be smaller than the specified batch size (last batch)
        batch_size = image_feature.shape[0]
        beam_searcher = BeamSearch(beam_size, batch_size, max_sentence_length)

        # print('image_feature', image_feature.shape)
        # print('image_feature_fc7', image_feature_fc7.shape)
        image_feature = image_feature.to(device)
        image_feature_fc7 = image_feature

        # b, d, h, w = image_feature.shape
        # image_feature = image_feature.unsqueeze(1).expand(b, beam_size, d, h, w)
        # image_feature = image_feature.contiguous().view(b * beam_size, d, h, w)
        b, d = image_feature_fc7.shape
        image_feature_fc7 = image_feature_fc7.unsqueeze(1).expand(b, beam_size, d)
        image_feature_fc7 = image_feature_fc7.contiguous().view(b * beam_size, d)

        wordclass_feed = np.zeros((beam_size * batch_size, max_sentence_length), dtype='int64')
        wordclass_feed[:, 0] = start_word_id
        outcaps = np.empty((batch_size, 0)).tolist()

        for j in range(max_sentence_length - 1):
            wordclass = Variable(torch.from_numpy(wordclass_feed)).to(device)
            wordact, _ = lm_cnn.forward(None, image_feature_fc7, wordclass)
            wordact = wordact[:, :, :-1]
            wordact_j = wordact[..., j]

            beam_indices, wordclass_indices = beam_searcher.expand_beam(wordact_j)
            if len(beam_indices) == 0 or j == (max_sentence_length - 2):
                # Beam search is over.
                generated_captions = beam_searcher.get_results()
                for k in range(batch_size):
                    g = generated_captions[:, k]
                    outcaps[k] = [vocab.get_word(int(x.cpu())) for x in g]
            else:
                wordclass_feed = wordclass_feed[beam_indices]
                image_feature_fc7 = image_feature_fc7.index_select(
                    0, Variable(torch.LongTensor(beam_indices).to(device)))
                # image_feature = image_feature.index_select(0, Variable(torch.LongTensor(beam_indices).to(device)))
                for i, wordclass_idx in enumerate(wordclass_indices):
                    wordclass_feed[i, j + 1] = wordclass_idx

        for j in range(batch_size):
            num_words = len(outcaps[j])
            if end_word in outcaps[j]:
                num_words = outcaps[j].index(end_word)
                outcaps[j] = outcaps[j][:num_words]
            outcaps[j] = [w for w in outcaps[j] if w != end_word and w != start_word]
            outcap = ' '.join(outcaps[j][:num_words])

            if image_filename_list[j] not in result_set:
                result = {
                    'image_id': image_filename_list[j],
                    'caption': ''.join(outcap.split()),
                    'image_filename': image_filename_list[j]
                }
                print(result)
                result_obj.append(result)
                result_set.add(image_filename_list[j])
                annotation_obj['images'].append({'id': image_filename_list[j]})

            caption = captions[j]
            annotation_obj['annotations'].append({
                'image_id': image_filename_list[j],
                'caption': caption,
                'id': caption_id
            })
            caption_id += 1
        # print('--------')

    with open(output_file_name, 'w') as f_output:
        json.dump(result_obj, f_output, indent=4)
    with open(annotation_file_name, 'w') as f_ann:
        json.dump(annotation_obj, f_ann, indent=4)

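
# Hedged sketch, not part of the original file: the annotation / result JSON
# files written by test1() follow the COCO captions format, so they could in
# principle be scored with pycocotools + pycocoevalcap, assuming both packages
# (and Java, which the PTB tokenizer and METEOR scorer require) are installed
# and the string video ids are accepted as image ids.
def evaluate_captions(annotation_file, result_file):
    from pycocotools.coco import COCO
    from pycocoevalcap.eval import COCOEvalCap
    coco = COCO(annotation_file)          # ground-truth captions
    coco_res = coco.loadRes(result_file)  # generated captions
    coco_eval = COCOEvalCap(coco, coco_res)
    coco_eval.evaluate()                  # prints BLEU / METEOR / ROUGE_L / CIDEr
    return coco_eval.eval                 # dict: metric name -> score
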
def test1(lm_lstm, global_step):
    lm_lstm.train(False)
    annotation_file_name = '../all_results/results_lstm_chn_word_c3dpool5/msvd_annotation{}.json'.format(global_step)
    output_file_name = '../all_results/results_lstm_chn_word_c3dpool5/msvd_result{}.json'.format(global_step)
    if not os.path.exists(os.path.dirname(annotation_file_name)):
        os.makedirs(os.path.dirname(annotation_file_name))

    length_normalization_factor = 0.0
    beam_size = 3
    max_sentence_length = 15

    vocab = Vocabulary.load(vocab_path)
    dataset = data_video.MSVDDatasetCHN(vocab=vocab,
                                        segment_method=segment_method,
                                        caption_mode='text',
                                        split='test',
                                        feature=feature_type)
    collate_fn = data_video.collate_fn
    # data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=True, collate_fn=collate_fn, num_workers=4)
    data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    dataset_size = len(dataset)

    start_word_id = vocab.get_index(start_word)
    end_word_id = vocab.get_index(end_word)

    result_generator = COCOResultGenerator()

    for index, (image_filename_list, features, captions) in enumerate(data_loader):
        # write annotation
        image_id = image_filename_list[0]
        caption = captions[0]
        result_generator.add_annotation(image_id, caption)
        if result_generator.has_output(image_id):
            continue
        if len(result_generator.test_image_set) >= 1970:
            break

        # extract the video feature embedding and replicate it once per beam
        features = features.to(device)
        image_embedding = lm_lstm.get_image_embedding(features)
        image_embedding = image_embedding.repeat([beam_size, 1])

        inputs = image_embedding.unsqueeze_(0)
        states = None

        initial_beam = Caption(sentence=[start_word_id],
                               state=states,
                               logprob=0.0,
                               score=0.0,
                               metadata=[""])
        partial_captions = TopN(beam_size)
        partial_captions.push(initial_beam)
        complete_captions = TopN(beam_size)
        output_softmax = nn.Softmax(dim=-1)

        for j in range(max_sentence_length):
            try:
                partial_captions_list = partial_captions.extract()
                partial_captions.reset()

                if j > 0:
                    ii = torch.tensor([c.sentence[-1] for c in partial_captions_list])
                    ii = ii.to(device)
                    inputs = lm_lstm.get_word_embedding(ii)
                    inputs = inputs.unsqueeze_(0)
                    states = [None, None]
                    states[0] = torch.cat([c.state[0] for c in partial_captions_list], dim=1)  # (1, beam_size, 512)
                    states[1] = torch.cat([c.state[1] for c in partial_captions_list], dim=1)  # (1, beam_size, 512)

                hiddens, states = lm_lstm.lstm(inputs, states)
                # project the LSTM hidden states onto the vocabulary
                outputs = lm_lstm.output_word_layer(hiddens.squeeze(0))
                softmax = output_softmax(outputs)

                for i, partial_caption in enumerate(partial_captions_list):
                    word_probabilities = softmax[i].detach().cpu().numpy()  # cuda tensor -> cpu for sorting
                    # state = (states[0][0][i].detach().cpu().numpy(), states[1][0][i].detach().cpu().numpy())
                    state = (states[0][:, i:i + 1], states[1][:, i:i + 1])
                    words_and_probs = list(enumerate(word_probabilities))
                    words_and_probs.sort(key=lambda x: -x[1])
                    words_and_probs = words_and_probs[0:beam_size]
                    # print([(vocab.get_word(w), p) for w, p in words_and_probs])
                    for w, p in words_and_probs:
                        if p < 1e-12:
                            continue  # Avoid log(0).
                        sentence = partial_caption.sentence + [w]
                        logprob = partial_caption.logprob + math.log(p)
                        score = logprob
                        metadata_list = None
                        if w == end_word_id:
                            if length_normalization_factor > 0:
                                score /= len(sentence) ** length_normalization_factor
                            beam = Caption(sentence, state, logprob, score, metadata_list)
                            complete_captions.push(beam)
                        else:
                            beam = Caption(sentence, state, logprob, score, metadata_list)
                            partial_captions.push(beam)

                if partial_captions.size() == 0:
                    break
            except Exception:
                exc_info = sys.exc_info()
                traceback.print_exception(*exc_info)
                IPython.embed()

        if not complete_captions.size():
            complete_captions = partial_captions
        captions = complete_captions.extract(sort=True)

        print(len(result_generator.test_image_set))
        print('{}, {}/{} {}'.format(image_id, index, dataset_size, result_generator.has_output(image_id)))

        for i, caption in enumerate(captions):
            sentence = [vocab.get_word(w) for w in caption.sentence]
            sentence = [w for w in sentence if (w != start_word and w != end_word)]  # ignore start and end tokens
            sentence = "".join(sentence)
            print(" %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob)))
            if i == 0:
                print(sentence)
                result_generator.add_output(image_id, sentence)

    annotation_obj, result_obj = result_generator.get_annotation_and_output()
    # print(annotation_obj)
    with open(annotation_file_name, 'w') as f:
        json.dump(annotation_obj, f)
    with open(output_file_name, 'w') as f:
        json.dump(result_obj, f)

    lm_lstm.train(True)

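
# Hypothetical entry point, not part of the original file: the CNN and LSTM
# variants above appear to live in separate training scripts, and each script
# could simply run its own train() loop (test1() is called from inside train()).
if __name__ == '__main__':
    train()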