# --- KTH frame-prediction training script ---
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.utils import save_image

# EncoderDecoder, fetch_kth_data, eval_model and KTH_PATH are assumed to be
# defined elsewhere in the project.


def main(args):
    num_frames = 15     # frames per KTH clip
    ms_per_frame = 40   # 25 fps -> 40 ms between consecutive frames

    network = EncoderDecoder(args).cuda()
    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr, betas=(0.9, 0.99))
    criterion = nn.MSELoss()

    train_loader, dev_loader, test_loader = fetch_kth_data(args)

    # test_tens = next(iter(train_loader))['instance'][0, :, :, :, :].transpose(0, 1)
    # print(test_tens.shape)
    # save_image(test_tens, './img/test_tens.png')
    # print(next(iter(train_loader))['instance'][0, :, 0, :, :].shape)

    train_loss = []
    dev_loss = []
    for epoch in range(args.epochs):
        epoch_loss = 0
        batch_num = 0
        for item in train_loader:
            # label = item['label']
            item = item['instance'].cuda()
            frames_processed = 0
            batch_loss = 0
            # Fit the whole batch on every ordered frame pair (i -> j),
            # conditioning the network on the elapsed time in milliseconds.
            for i in range(num_frames - 1):
                for j in range(i + 1, num_frames):
                    network.zero_grad()
                    frame_diff = j - i
                    time_delta = torch.tensor(frame_diff * ms_per_frame).float().repeat(args.batch_size).cuda()
                    time_delta.requires_grad = True

                    seq = item[:, :, i, :, :]
                    # print(seq.shape)
                    # downsample
                    # seq = F.interpolate(seq, size=(64, 64))
                    # print(seq.shape)
                    seq.requires_grad = True

                    seq_targ = item[:, :, j, :, :]
                    # downsample
                    # seq_targ = F.interpolate(seq_targ, size=(64, 64))
                    seq_targ.requires_grad = False

                    assert seq.requires_grad and time_delta.requires_grad, 'No Gradients'

                    outputs = network(seq, time_delta)
                    error = criterion(outputs, seq_targ)
                    error.backward()
                    optimizer.step()

                    batch_loss += error.cpu().item()
                    frames_processed += 1
                    if i == 0:
                        save_image(outputs, '/scratch/eecs-share/dinkinst/kth/img/train_output_{}_epoch_{}.png'.format(j, epoch))

            batch_num += 1
            epoch_loss += batch_loss
            print('Epoch {} Batch #{} Total Error {}'.format(epoch, batch_num, batch_loss))

        # frames_processed holds the number of frame pairs in the last batch,
        # so the scaled loss is the epoch loss per frame pair of a full batch.
        print('\nEpoch {} Total Loss {} Scaled Loss {}\n'.format(epoch, epoch_loss, epoch_loss / frames_processed))
        train_loss.append(epoch_loss)

        if epoch % 10 == 0:
            torch.save(network.state_dict(), KTH_PATH + '/model_new_{}.pth'.format(epoch))
            torch.save(optimizer.state_dict(), KTH_PATH + '/optim_new_{}.pth'.format(epoch))

        dev_loss.append(eval_model(network, dev_loader, epoch))
        network.train()  # restore training mode after evaluation

    plt.plot(range(args.epochs), train_loss)
    plt.grid()
    plt.savefig('/scratch/eecs-share/dinkinst/kth/img/loss_train.png', dpi=64)
    plt.close('all')

    plt.plot(range(args.epochs), dev_loss)
    plt.grid()
    plt.savefig('/scratch/eecs-share/dinkinst/kth/img/loss_dev.png', dpi=64)
    plt.close('all')
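# A minimal, hypothetical entry point for the script above (not part of the
# original code). It assumes argparse flag names --lr, --epochs and
# --batch_size; main() reads exactly those fields, but EncoderDecoder(args)
# and fetch_kth_data(args) may expect further fields in the real project.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train the KTH frame-prediction encoder-decoder')
    parser.add_argument('--lr', type=float, default=1e-4, help='Adam learning rate (assumed default)')
    parser.add_argument('--epochs', type=int, default=50, help='number of training epochs (assumed default)')
    parser.add_argument('--batch_size', type=int, default=8, help='mini-batch size (assumed default)')
    main(parser.parse_args())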
# --- Neural machine translation (encoder-decoder) training script ---
import os
import torch
import torch.nn as nn

# create_data, BatchIterator, EncoderDecoder, train, evaluate,
# print_runtime_metric, metric_store and the configuration globals
# (data paths, hyperparameters and flags) are assumed to be defined
# elsewhere in the project.


def main():
    torch.manual_seed(10)  # fix seeds for reproducibility
    torch.cuda.manual_seed(10)

    train_data, train_source_text, train_target_text = create_data(
        os.path.join(train_data_dir, train_dataset), lang)
    # dev_data, dev_source_text, dev_target_text = create_data(os.path.join(eval_data_dir, 'newstest2012_2013'), lang)
    eval_data, eval_source_text, eval_target_text = create_data(
        os.path.join(dev_data_dir, eval_dataset), lang)

    en_emb_lookup_matrix = train_source_text.vocab.vectors.to(device)
    target_emb_lookup_matrix = train_target_text.vocab.vectors.to(device)

    global en_vocab_size
    global target_vocab_size
    en_vocab_size = train_source_text.vocab.vectors.size(0)
    target_vocab_size = train_target_text.vocab.vectors.size(0)

    if verbose:
        print('English vocab size: ', en_vocab_size)
        print(lang, 'vocab size: ', target_vocab_size)
        print_runtime_metric('Vocabs loaded')

    model = EncoderDecoder(en_emb_lookup_matrix, target_emb_lookup_matrix,
                           hidden_size, bidirectional, attention,
                           attention_type, decoder_cell_type).to(device)
    model.encoder.device = device

    # ignore_index=1 matches the padding index produced by the data iterator.
    criterion = nn.CrossEntropyLoss(ignore_index=1)
    # optimiser = torch.optim.Adadelta(model.parameters(), lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)  # the exact optimiser in the paper, which uses rho=0.95
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)

    best_loss = 10e+10  # sentinel, replaced by the first evaluation loss
    best_bleu = 0
    epoch = 1  # initial epoch id

    if resume:
        print('\n ---------> Resuming training <----------')
        checkpoint_path = os.path.join(save_dir, 'checkpoint.pth')
        checkpoint = torch.load(checkpoint_path)
        epoch = checkpoint['epoch']
        subepoch, num_subepochs = checkpoint['subepoch_num']
        model.load_state_dict(checkpoint['state_dict'])
        best_loss = checkpoint['best_loss']
        best_bleu = checkpoint['best_BLEU']  # restore the best BLEU along with the best loss
        optimiser.load_state_dict(checkpoint['optimiser'])
        is_best = checkpoint['is_best']
        metric_store.load(os.path.join(save_dir, 'checkpoint_metrics.pickle'))
        # Continue from the next sub-epoch (or the next epoch if the last one finished).
        if subepoch == num_subepochs:
            epoch += 1
            subepoch = 1
        else:
            subepoch += 1

    if verbose:
        print_runtime_metric('Model initialised')

    resume_subepoch = resume  # keep the checkpointed sub-epoch counter only for the first epoch
    while epoch <= num_epochs:
        is_best = False  # whether this sub-epoch improved the best loss or BLEU

        # Initialise the training iterator with a per-epoch shuffling seed.
        train_iter = BatchIterator(train_data, batch_size, do_train=True, seed=epoch**2)
        num_subepochs = train_iter.num_batches // subepoch_size

        # Train in sub-epochs, starting from the checkpointed sub-epoch when
        # resuming and from sub-epoch 1 otherwise (and for every later epoch).
        if not resume_subepoch:
            subepoch = 1
        resume_subepoch = False

        while subepoch <= num_subepochs:
            if verbose:
                print(' Running code on: ', device)
                print('------> Training epoch {}, sub-epoch {}/{} <------'.format(
                    epoch, subepoch, num_subepochs))

            mean_train_loss = train(model, criterion, optimiser, train_iter,
                                    train_source_text, train_target_text,
                                    subepoch, num_subepochs)

            if verbose:
                print_runtime_metric('Training sub-epoch complete')
                print('------> Evaluating sub-epoch {} <------'.format(subepoch))

            # Evaluate on the held-out data.
            eval_iter = BatchIterator(eval_data, batch_size, do_train=False, seed=325632)
            mean_eval_loss, mean_eval_bleu, _, mean_eval_sent_bleu, _, _ = evaluate(
                model, criterion, eval_iter, eval_source_text.vocab,
                eval_target_text.vocab, train_source_text.vocab,
                train_target_text.vocab)

            if verbose:
                print_runtime_metric('Evaluating sub-epoch complete')

            if mean_eval_loss < best_loss:
                best_loss = mean_eval_loss
                is_best = True
            if mean_eval_bleu > best_bleu:
                best_bleu = mean_eval_bleu
                is_best = True

            config_dict = {
                'train_dataset': train_dataset,
                'b_size': batch_size,
                'h_size': hidden_size,
                'bidirectional': bidirectional,
                'attention': attention,
                'attention_type': attention_type,
                'decoder_cell_type': decoder_cell_type
            }

            # Save the model and optimiser state after every sub-epoch so training can resume.
            checkpoint = {
                'epoch': epoch,
                'subepoch_num': (subepoch, num_subepochs),
                'state_dict': model.state_dict(),
                'config': config_dict,
                'best_loss': best_loss,
                'best_BLEU': best_bleu,
                'optimiser': optimiser.state_dict(),
                'is_best': is_best
            }
            torch.save(checkpoint, os.path.join(save_dir, 'checkpoint.pth'))
            metric_store.log(mean_train_loss, mean_eval_loss)
            metric_store.save(os.path.join(save_dir, 'checkpoint_metrics.pickle'))
            if verbose:
                print('Checkpoint.')

            # Save the best model so far.
            if is_best:
                save_dict = {
                    'state_dict': model.state_dict(),
                    'config': config_dict,
                    'epoch': epoch
                }
                torch.save(save_dict, os.path.join(save_dir, 'best_model.pth'))
                metric_store.save(os.path.join(save_dir, 'best_model_metrics.pickle'))

            if verbose:
                if is_best:
                    print('Best model saved!')
                print('Ep {} Sub-ep {}/{} Tr loss {} Eval loss {} Eval BLEU {} Eval sent BLEU {}'.format(
                    epoch, subepoch, num_subepochs,
                    round(mean_train_loss, 3), round(mean_eval_loss, 3),
                    round(mean_eval_bleu, 4), round(mean_eval_sent_bleu, 4)))

            subepoch += 1
        epoch += 1
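# Hypothetical module-level configuration for the script above (illustrative
# only, not the original settings). main() reads all of these as globals; the
# values here are assumptions that merely make its dependencies explicit, and
# metric_store is additionally assumed to be a project-provided metrics logger.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_data_dir = './data/train'        # assumed directory layout
dev_data_dir = './data/dev'
train_dataset = 'europarl'             # assumed corpus names
eval_dataset = 'newstest2013'
lang = 'de'                            # target-language code
hidden_size = 512                      # assumed hyperparameters
batch_size = 64
lr = 1e-3
bidirectional = True
attention = True
attention_type = 'general'
decoder_cell_type = 'lstm'
num_epochs = 10
subepoch_size = 1000                   # batches per sub-epoch
save_dir = './checkpoints'
verbose = True
resume = False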