# Imports assumed by this section; project-local names (EncoderDecoder,
# create_data, BatchIterator, metric_store, NoiseDataloader, mn, pp, and the
# configuration globals such as device, lang, save_dir) are defined elsewhere
# in the repo.
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn


def load_model(model_dir, en_emb_lookup_matrix, target_emb_lookup_matrix):
    """Restore a saved EncoderDecoder model (and its saved config) from
    `model_dir`, resolved relative to the parent directory of `cwd`."""
    save_dict = torch.load(os.path.join(os.path.dirname(cwd), model_dir))
    config = save_dict['config']
    print(' Model config: \n', config)
    model = EncoderDecoder(en_emb_lookup_matrix, target_emb_lookup_matrix,
                           config['h_size'], config['bidirectional'],
                           config['attention'], config['attention_type'],
                           config['decoder_cell_type']).to(device)
    mn.hidden_size = config['h_size']
    model.encoder.device = device
    model.load_state_dict(save_dict['state_dict'])
    return model
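# A minimal usage sketch for load_model. The checkpoint path
# 'saved_models/best_model.pth' is hypothetical; `device`, `dev_data_dir`,
# `eval_dataset`, `lang` and create_data() are the module globals used in
# main() below.
def example_restore_best_model():
    _, source_text, target_text = create_data(
        os.path.join(dev_data_dir, eval_dataset), lang)
    model = load_model('saved_models/best_model.pth',
                       source_text.vocab.vectors.to(device),
                       target_text.vocab.vectors.to(device))
    model.eval()  # inference mode: disables dropout
    return model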
def main():
    torch.manual_seed(10)  # fix seeds for reproducibility
    torch.cuda.manual_seed(10)

    train_data, train_source_text, train_target_text = create_data(
        os.path.join(train_data_dir, train_dataset), lang)
    # dev_data, dev_source_text, dev_target_text = create_data(
    #     os.path.join(eval_data_dir, 'newstest2012_2013'), lang)
    eval_data, eval_source_text, eval_target_text = create_data(
        os.path.join(dev_data_dir, eval_dataset), lang)

    en_emb_lookup_matrix = train_source_text.vocab.vectors.to(device)
    target_emb_lookup_matrix = train_target_text.vocab.vectors.to(device)

    global en_vocab_size
    global target_vocab_size
    en_vocab_size = train_source_text.vocab.vectors.size(0)
    target_vocab_size = train_target_text.vocab.vectors.size(0)
    if verbose:
        print('English vocab size: ', en_vocab_size)
        print(lang, 'vocab size: ', target_vocab_size)
        print_runtime_metric('Vocabs loaded')

    model = EncoderDecoder(en_emb_lookup_matrix, target_emb_lookup_matrix,
                           hidden_size, bidirectional, attention,
                           attention_type, decoder_cell_type).to(device)
    model.encoder.device = device

    # ignore_index=1 matches the padding index produced by the data iterator
    criterion = nn.CrossEntropyLoss(ignore_index=1)
    # optimiser = torch.optim.Adadelta(model.parameters(), lr=1.0, rho=0.9,
    #                                  eps=1e-06, weight_decay=0)  # the exact optimiser in the paper uses rho=0.95
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)

    best_loss = float('inf')  # sentinel: the first evaluation always improves on it
    best_bleu = 0
    epoch = 1  # initial epoch id
    resuming = resume  # local flag: only the first resumed epoch starts mid-way

    if resuming:
        print('\n ---------> Resuming training <----------')
        checkpoint_path = os.path.join(save_dir, 'checkpoint.pth')
        checkpoint = torch.load(checkpoint_path)
        epoch = checkpoint['epoch']
        subepoch, num_subepochs = checkpoint['subepoch_num']
        model.load_state_dict(checkpoint['state_dict'])
        best_loss = checkpoint['best_loss']
        optimiser.load_state_dict(checkpoint['optimiser'])
        is_best = checkpoint['is_best']
        metric_store.load(os.path.join(save_dir, 'checkpoint_metrics.pickle'))
        # Continue from the sub-epoch after the checkpointed one
        if subepoch == num_subepochs:
            epoch += 1
            subepoch = 1
        else:
            subepoch += 1

    if verbose:
        print_runtime_metric('Model initialised')

    while epoch <= num_epochs:
        is_best = False  # whether this sub-epoch improved loss or BLEU

        # Initialise the training iterator with an epoch-dependent seed
        train_iter = BatchIterator(train_data, batch_size, do_train=True,
                                   seed=epoch**2)
        num_subepochs = train_iter.num_batches // subepoch_size

        # When resuming, keep the checkpointed sub-epoch counter;
        # otherwise every epoch starts from sub-epoch 1
        if not resuming:
            subepoch = 1

        while subepoch <= num_subepochs:
            if verbose:
                print(' Running code on: ', device)
                print('------> Training epoch {}, sub-epoch {}/{} <------'.format(
                    epoch, subepoch, num_subepochs))
            mean_train_loss = train(model, criterion, optimiser, train_iter,
                                    train_source_text, train_target_text,
                                    subepoch, num_subepochs)
            if verbose:
                print_runtime_metric('Training sub-epoch complete')
                print('------> Evaluating sub-epoch {} <------'.format(subepoch))

            # Evaluate on the held-out eval data
            eval_iter = BatchIterator(eval_data, batch_size, do_train=False,
                                      seed=325632)
            mean_eval_loss, mean_eval_bleu, _, mean_eval_sent_bleu, _, _ = evaluate(
                model, criterion, eval_iter, eval_source_text.vocab,
                eval_target_text.vocab, train_source_text.vocab,
                train_target_text.vocab)
            if verbose:
                print_runtime_metric('Evaluating sub-epoch complete')

            if mean_eval_loss < best_loss:
                best_loss = mean_eval_loss
                is_best = True
            if mean_eval_bleu > best_bleu:
                best_bleu = mean_eval_bleu
                is_best = True

            config_dict = {
                'train_dataset': train_dataset,
                'b_size': batch_size,
                'h_size': hidden_size,
                'bidirectional': bidirectional,
                'attention': attention,
                'attention_type': attention_type,
                'decoder_cell_type': decoder_cell_type
            }

            # Save the model and optimiser state for resumption (after each sub-epoch)
            checkpoint = {
                'epoch': epoch,
                'subepoch_num': (subepoch, num_subepochs),
                'state_dict': model.state_dict(),
                'config': config_dict,
                'best_loss': best_loss,
                'best_BLEU': best_bleu,
                'optimiser': optimiser.state_dict(),
                'is_best': is_best
            }
            torch.save(checkpoint, os.path.join(save_dir, 'checkpoint.pth'))
            metric_store.log(mean_train_loss, mean_eval_loss)
            metric_store.save(os.path.join(save_dir, 'checkpoint_metrics.pickle'))
            if verbose:
                print('Checkpoint.')

            # Save the best model so far
            if is_best:
                save_dict = {
                    'state_dict': model.state_dict(),
                    'config': config_dict,
                    'epoch': epoch
                }
                torch.save(save_dict, os.path.join(save_dir, 'best_model.pth'))
                metric_store.save(os.path.join(save_dir, 'best_model_metrics.pickle'))

            if verbose:
                if is_best:
                    print('Best model saved!')
                print('Ep {} Sub-ep {}/{} Tr loss {} Eval loss {} '
                      'Eval BLEU {} Eval sent BLEU {}'.format(
                          epoch, subepoch, num_subepochs,
                          round(mean_train_loss, 3), round(mean_eval_loss, 3),
                          round(mean_eval_bleu, 4), round(mean_eval_sent_bleu, 4)))

            subepoch += 1

        resuming = False  # subsequent epochs start from sub-epoch 1 again
        epoch += 1
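# Standard script entry point; a sketch, assuming this module is run directly
# rather than through a separate CLI wrapper.
if __name__ == '__main__':
    main()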
def test(noise_type):
    global test_dataset
    # The three supported noise types build the same TEST dataset, so the
    # per-type branches collapse into a single membership check.
    if noise_type not in (NoiseDataloader.GAUSSIAN,
                          NoiseDataloader.TEXT_OVERLAY,
                          NoiseDataloader.SALT_PEPPER):
        return
    test_dataset = NoiseDataloader(dataset_type=NoiseDataloader.TEST,
                                   noisy_per_image=1,
                                   noise_type=noise_type)

    # Initialise the network (DataParallel to match how the weights were saved)
    network = EncoderDecoder()
    network = nn.DataParallel(network)

    instance = '010'
    pretrained_model_folder_path = os.path.join(pp.trained_models_folder_path,
                                                'Instance_' + instance)
    loaded_weights_file = None  # remember which checkpoint actually loaded
    for pretrained_model_file_name in os.listdir(pretrained_model_folder_path):
        if not pretrained_model_file_name.endswith('.pt'):
            continue
        try:
            network.load_state_dict(
                torch.load(os.path.join(pretrained_model_folder_path,
                                        pretrained_model_file_name)))
            loaded_weights_file = pretrained_model_file_name
            print('Network weights initialized using file from:',
                  loaded_weights_file)
        except Exception:  # skip incompatible or corrupted checkpoints
            print('Unable to load network with weights from:',
                  pretrained_model_file_name)
    if loaded_weights_file is None:
        print('No loadable .pt checkpoint found in:',
              pretrained_model_folder_path)
        return

    network.eval()
    # random.randint is inclusive on both ends, so sample up to len - 1
    idx = random.randint(0, len(test_dataset) - 1)
    noisy_image, clean_image = test_dataset[idx]
    with torch.no_grad():  # inference only: no autograd bookkeeping
        predicted_image = network(
            torch.unsqueeze(torch.as_tensor(noisy_image), dim=0))[0]

    clean_image = NoiseDataloader.convert_model_output_to_image(clean_image)
    noisy_image = NoiseDataloader.convert_model_output_to_image(noisy_image)
    predicted_image = NoiseDataloader.convert_model_output_to_image(predicted_image)

    plt.figure(num='Network Performance using weights at {}'.format(
        loaded_weights_file), figsize=(20, 20))
    plt.subplot(2, 2, 1)
    plt.imshow(clean_image, cmap='gray')
    plt.colorbar()
    plt.title('Original Image')

    plt.subplot(2, 2, 2)
    plt.imshow(noisy_image, cmap='gray')
    plt.colorbar()
    plt.title('Noisy Image')

    plt.subplot(2, 2, 3)
    plt.imshow(predicted_image, cmap='gray')
    plt.colorbar()
    plt.title('Predicted Image')

    plt.subplot(2, 2, 4)
    # Per-pixel Euclidean distance between the clean and denoised images
    plt.imshow(np.sqrt(np.sum((clean_image - predicted_image) ** 2, axis=2)),
               cmap='gray')
    plt.title('Euclidean Distance')
    plt.colorbar()
    plt.show()
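# A small driver sketch for test(): visualise the denoiser on each supported
# noise type. The constants come from NoiseDataloader above; this assumes the
# trained checkpoints for Instance_010 exist under pp.trained_models_folder_path.
def demo_all_noise_types():
    for nt in (NoiseDataloader.GAUSSIAN,
               NoiseDataloader.TEXT_OVERLAY,
               NoiseDataloader.SALT_PEPPER):
        test(nt)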