def main(args):
    """Train the encoder/decoder captioning model and report progress.

    Per epoch: one training pass, one no-grad validation pass, BLEU-1..4
    evaluation, TensorBoard scalar logging, periodic checkpointing, and an
    exponential learning-rate decay (x0.995 per epoch). After training,
    optionally runs beam-search captioning on randomly sampled train/val
    items and prints the results.

    Args:
        args: parsed CLI namespace. Fields read here: PCA, norm, training,
            testing, hidden_size, embed_size, attention_size, load_weight,
            save_weight, model_path, num_epochs, learning_rate, predict,
            numOfpredection, show_image.

    Side effects: writes TensorBoard event files, saves model weights under
    ``args.model_path``, mutates ``args.learning_rate``, prints to stdout.
    """
    data = DataLoader(pca=args.PCA, norm=args.norm)
    train_captions, train_feature, train_url, train_len = data.get_Training_data(
        args.training)
    test_captions, test_feature, test_url, test_len = data.get_val_data(
        args.testing)
    # f: eval features, c: eval captions (third element unused here)
    f, c, _ = data.eval_data()

    writer = SummaryWriter()

    encoder = Encoder(input_size=train_feature.shape[1],
                      hidden_size=args.hidden_size).to(device)
    decoder = Decoder(embed_size=args.embed_size,
                      hidden_size=args.hidden_size,
                      attention_dim=args.attention_size,
                      vocab_size=len(data.word_to_idx)).to(device)

    if args.load_weight:
        load_weights(encoder, args.model_path + "Jul28_10-04-57encoder")
        load_weights(decoder, args.model_path + "Jul28_10-04-57decoder")

    # FIX: build the loss and optimizer ONCE. The original re-created the
    # Adam optimizer every epoch, which silently discarded its running moment
    # estimates; the re-creation only served to pick up the decayed learning
    # rate, which is now written into the param groups at the end of each
    # epoch instead.
    params = list(decoder.parameters()) + list(encoder.parameters())
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=params, lr=args.learning_rate)

    for epoch in range(args.num_epochs):
        training_loss = step(encoder=encoder,
                             decoder=decoder,
                             criterion=criterion,
                             data=(train_captions, train_feature, train_len),
                             optimizer=optimizer)

        # Validation + BLEU evaluation: no gradients needed.
        with torch.no_grad():
            test_loss = step(encoder=encoder,
                             decoder=decoder,
                             criterion=criterion,
                             data=(test_captions, test_feature, test_len))
            b1, b2, b3, b4 = evaluate(encoder, decoder, f, c, 5,
                                      data.word_to_idx, data.idx_to_word)

        writer.add_scalars('BLEU', {
            'BLEU1': b1,
            'BLEU2': b2,
            'BLEU3': b3,
            'BLEU4': b4
        }, epoch + 1)

        # Unconditional periodic checkpoint (also fires at epoch 0).
        if (epoch % 30) == 0:
            save_weights(encoder, args.model_path + "encoder" + str(epoch))
            save_weights(decoder, args.model_path + "decoder" + str(epoch))

        writer.add_scalars('loss', {
            'train': training_loss,
            'val': test_loss
        }, epoch + 1)
        print(
            'Epoch [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}, TestLoss: {:.4f}, TestPerplexity: {:5.4f}'
            .format(epoch + 1, args.num_epochs, training_loss,
                    np.exp(training_loss), test_loss, np.exp(test_loss)))

        # Exponential LR decay, applied in place so the (single) Adam
        # optimizer keeps its state across epochs.
        args.learning_rate *= 0.995
        for group in optimizer.param_groups:
            group['lr'] = args.learning_rate

        if args.save_weight:
            save_weights(encoder, args.model_path + "encoder" + str(epoch))
            save_weights(decoder, args.model_path + "decoder" + str(epoch))

    # Final (unnumbered) snapshot after the last epoch.
    if args.save_weight:
        save_weights(encoder, args.model_path + "encoder")
        save_weights(decoder, args.model_path + "decoder")

    if args.predict:
        sample = Sample(encoder=encoder, decoder=decoder, device=device)

        # Random sample indices (with possible repeats) from each split.
        train_mask = [
            random.randint(0, train_captions.shape[0] - 1)
            for _ in range(args.numOfpredection)
        ]
        test_mask = [
            random.randint(0, test_captions.shape[0] - 1)
            for _ in range(args.numOfpredection)
        ]

        train_featur = torch.from_numpy(train_feature[train_mask]).to(device)
        train_encoder_out = encoder(train_featur)
        test_featur = torch.from_numpy(test_feature[test_mask]).to(device)
        test_encoder_out = encoder(test_featur)

        train_output = []
        test_output = []
        # Both masks have length args.numOfpredection, so one loop covers both.
        # NOTE(review): beam widths differ deliberately in the original
        # (2 for train, 50 for val) — preserved here.
        for i in range(len(test_mask)):
            print(i)
            pre = sample.caption_image_beam_search(
                train_encoder_out[i].reshape(1, args.embed_size),
                data.word_to_idx, 2)
            train_output.append(pre)
            pre = sample.caption_image_beam_search(
                test_encoder_out[i].reshape(1, args.embed_size),
                data.word_to_idx, 50)
            test_output.append(pre)

        print_output(output=test_output,
                     sample=0,
                     gt=test_captions[test_mask],
                     img=test_url[test_mask],
                     title="val",
                     show_image=args.show_image,
                     idx_to_word=data.idx_to_word)

        print("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
        print("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
        print("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
        print("")

        print_output(output=train_output,
                     sample=0,
                     gt=train_captions[train_mask],
                     img=train_url[train_mask],
                     title="training",
                     show_image=args.show_image,
                     idx_to_word=data.idx_to_word)