def train():
    """Train the captioning model and validate it after every epoch.

    Builds the train/val dataloaders from the annotation files named in
    ``Constants``, creates a ``CNNtoRNN`` model, a cross-entropy loss that
    ignores the PAD token and an Adam optimizer, optionally resumes from a
    checkpoint (``Constants.load_model``) and then runs
    ``Hyper.total_epochs`` epochs.

    After each epoch the score returned by ``validate`` is compared with
    the best score seen so far, and the number of consecutive epochs
    without improvement is printed.

    When ``Constants.save_model`` is set, a checkpoint is written at the
    start of every epoch (``save_checkpoint``) and a fresh one at the end
    of every epoch (``save_checkpoint_epoch``).
    """
    ann_dir = Constants.data_folder_ann

    coco_dataloader_train, coco_data_train = get_dataloader(
        os.path.join(ann_dir, Constants.captions_train_file),
        os.path.join(ann_dir, Constants.instances_train_file),
        "train",
    )
    # Only the validation loader is needed below; the dataset object is not.
    coco_dataloader_val, _ = get_dataloader(
        os.path.join(ann_dir, Constants.captions_val_file),
        os.path.join(ann_dir, Constants.instances_val_file),
        "val",
    )

    step = 0
    best_bleu4 = 0
    # Must live outside the epoch loop, otherwise it is reset every epoch
    # and can never count more than one epoch without improvement.
    epochs_since_improvement = 0

    # Initialize model, loss, etc.
    model = CNNtoRNN(coco_data_train.vocab)
    model = model.to(Constants.device)
    criterion = nn.CrossEntropyLoss(
        ignore_index=coco_data_train.vocab.stoi[Constants.PAD])
    optimizer = optim.Adam(model.parameters(), lr=Hyper.learning_rate)

    if Constants.load_model:
        step = load_checkpoint(model, optimizer)

    for epoch in range(1, Hyper.total_epochs + 1):
        # Set model to training mode (the sub-module calls are kept from the
        # original; model.train() normally covers registered sub-modules).
        model.train()
        model.decoderRNN.train()
        model.encoderCNN.train()
        print(f"Epoch: {epoch}")

        if Constants.save_model:
            save_checkpoint(_make_checkpoint(model, optimizer, step))

        for imgs, captions in tqdm(coco_dataloader_train,
                                   total=len(coco_dataloader_train),
                                   leave=False):
            imgs = imgs.to(Constants.device)
            captions = captions.to(Constants.device)

            # Forward pass: the model is fed the captions without their last
            # position and is scored against the full captions.
            # NOTE(review): presumably captions are (seq_len, batch) - confirm.
            outputs = model(imgs, captions[:-1])
            vocab_size = outputs.shape[2]
            loss = criterion(outputs.reshape(-1, vocab_size),
                             captions.reshape(-1))

            optimizer.zero_grad()
            # Plain backward(): passing the loss itself as the `gradient`
            # argument would scale every gradient by the loss value.
            loss.backward()
            optimizer.step()
            step += 1  # keep the checkpointed step counter meaningful

        # Snapshot AFTER the epoch so the per-epoch checkpoint holds the
        # current optimizer state and step, and honour the save flag (the
        # snapshot used to be undefined when saving was disabled).
        if Constants.save_model:
            save_checkpoint_epoch(
                _make_checkpoint(model, optimizer, step), epoch)

        # One epoch's validation
        recent_bleu4 = validate(val_loader=coco_dataloader_val,
                                model=model,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0


def _make_checkpoint(model, optimizer, step):
    """Bundle model state, optimizer state and the step counter in a dict."""
    return {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
    }
print( "Example 6 OUTPUT: " + " ".join(model.caption_image(test_img6.to(Constants.device), vocab)) ) """ model.train() def get_image(path, t_): temp1 = Image.open(path).convert("RGB") temp2 = np.array(temp1) temp3 = t_(temp2) temp4 = temp3.unsqueeze(0) if Hyper.is_grayscale: img = T.cat((temp4, temp4, temp4), 1) return img img = temp4 return img if __name__ == "__main__": with open(Constants.vocab_file, 'rb') as f: vocab = pickle.load(f) print('Vocabulary successfully loaded from the vocab.pkl file') epoch = 2 model = CNNtoRNN(vocab) model = model.to(Constants.device) optimizer = optim.Adam(model.parameters(), lr=Hyper.learning_rate) ##################################################################### _ = load_checkpoint_epoch(model, optimizer, epoch) print_examples(model, vocab)