def main(args):
    """Train the time-conditioned EncoderDecoder on KTH frame pairs.

    For every ordered frame pair (i, j) with i < j in each clip, the network
    learns to predict frame j from frame i conditioned on the elapsed time
    (j - i) * 40 ms.  Writes sample outputs and loss curves to disk and
    checkpoints model/optimizer state every 10 epochs.

    Args:
        args: namespace with at least `lr`, `epochs`, `batch_size`, and
            whatever `EncoderDecoder` / `fetch_kth_data` consume.
    """
    num_frames = 15    # frames used per clip
    ms_per_frame = 40  # assumes 25 fps source video -> 40 ms/frame; TODO confirm

    network = EncoderDecoder(args).cuda()
    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr, betas=(0.9, 0.99))
    criterion = nn.MSELoss()
    train_loader, dev_loader, test_loader = fetch_kth_data(args)

    train_loss = []
    dev_loss = []
    for epoch in range(args.epochs):
        epoch_loss = 0
        batch_num = 0
        # BUG FIX: accumulate the pair count over the WHOLE epoch; the original
        # reset it per batch, so the "Scaled Loss" divided by the last batch's
        # count only (and crashed on an empty loader).
        epoch_pairs = 0
        for item in train_loader:
            item = item['instance'].cuda()
            batch_loss = 0
            # Fit the whole batch for every (source, target) frame offset.
            for i in range(num_frames - 1):
                for j in range(i + 1, num_frames):
                    network.zero_grad()
                    # BUG FIX: size the time tensor by the actual batch —
                    # the last batch may be smaller than args.batch_size.
                    time_delta = torch.tensor(
                        (j - i) * ms_per_frame).float().repeat(item.size(0)).cuda()
                    # NOTE: inputs do not need requires_grad for parameter
                    # training; the original flags/assert were removed.
                    seq = item[:, :, i, :, :]
                    seq_targ = item[:, :, j, :, :]
                    outputs = network(seq, time_delta)
                    error = criterion(outputs, seq_targ)
                    error.backward()
                    optimizer.step()
                    batch_loss += error.cpu().item()
                    epoch_pairs += 1
                    if i == 0:
                        save_image(outputs, '/scratch/eecs-share/dinkinst/kth/img/train_output_{}_epoch_{}.png'.format(j, epoch))
            batch_num += 1
            epoch_loss += batch_loss
            print('Epoch {} Batch #{} Total Error {}'.format(epoch, batch_num, batch_loss))

        scaled_loss = epoch_loss / epoch_pairs if epoch_pairs else float('nan')
        print('\nEpoch {} Total Loss {} Scaled Loss {}\n'.format(epoch, epoch_loss, scaled_loss))
        train_loss.append(epoch_loss)
        if epoch % 10 == 0:
            torch.save(network.state_dict(), KTH_PATH + '/model_new_{}.pth'.format(epoch))
            torch.save(optimizer.state_dict(), KTH_PATH + '/optim_new_{}.pth'.format(epoch))
        dev_loss.append(eval_model(network, dev_loader, epoch))
        network.train()  # eval_model presumably switches to eval mode; restore

    plt.plot(range(args.epochs), train_loss)
    plt.grid()
    plt.savefig('/scratch/eecs-share/dinkinst/kth/img/loss_train.png', dpi=64)
    plt.close('all')
    plt.plot(range(args.epochs), dev_loss)
    plt.grid()
    plt.savefig('/scratch/eecs-share/dinkinst/kth/img/loss_dev.png', dpi=64)
    plt.close('all')
def train(train_loader, test_loader, gradient_clipping=1, hidden_state_size=10,
          lr=0.001, epochs=100, classify=True):
    """Train an MNIST EncoderDecoder for reconstruction or joint classification.

    Each 28x28 image is fed as a 28-step sequence of 28-dim rows.  Logs the
    per-epoch training loss to TensorBoard, validates every epoch, saves a
    checkpoint every 5 epochs or whenever validation improves, and finally
    plots the validation curves.

    Args:
        train_loader: iterable of (data, target) training batches.
        test_loader: validation loader, forwarded to `validation`.
        gradient_clipping: max grad-norm; falsy disables clipping.
        hidden_state_size: hidden dimension of the EncoderDecoder.
        lr: Adam learning rate.
        epochs: number of training epochs.
        classify: if True, train with an additional prediction head.
    """
    model = EncoderDecoder(input_size=28, hidden_size=hidden_state_size,
                           output_size=28, labels_num=10) if not classify \
        else EncoderDecoder(input_size=28, hidden_size=hidden_state_size,
                            output_size=28, is_prediction=True, labels_num=10)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_name = "mse"
    min_loss = float("inf")
    task_name = "classify" if classify else "reconstruct"
    validation_losses = []
    validation_accuracies = []
    tensorboard_writer = init_writer(results_path, lr, classify, hidden_state_size, epochs)

    # BUG FIX: range(1, epochs) ran only epochs-1 iterations, leaving the
    # validation-loss list one short of the `epochs` count handed to the
    # plotting helpers below.
    for epoch in range(1, epochs + 1):
        total_loss = 0
        total_batches = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.to(device)
            target = target.to(device)
            # Treat each image as a 28-step sequence of 28-dim rows.
            data_sequential = data.view(data.shape[0], 28, 28)
            optimizer.zero_grad()
            if classify:
                reconstructed_batch, batch_pred_probs = model(data_sequential)
                loss = model.loss(data_sequential, reconstructed_batch, target, batch_pred_probs)
            else:
                reconstructed_batch = model(data_sequential)
                loss = model.loss(data_sequential, reconstructed_batch)
            total_loss += loss.item()
            loss.backward()
            if gradient_clipping:
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clipping)
            optimizer.step()
            total_batches += 1

        epoch_loss = total_loss / total_batches
        tensorboard_writer.add_scalar('train_loss', epoch_loss, epoch)
        print(f'Train Epoch: {epoch} \t loss: {epoch_loss}')
        validation_loss = validation(model, test_loader, validation_losses, device,
                                     classify, validation_accuracies,
                                     tensorboard_writer, epoch)
        model.train()  # `validation` presumably switches to eval mode; restore
        if epoch % 5 == 0 or validation_loss < min_loss:
            file_name = f"ae_toy_{loss_name}_lr={lr}_hidden_size={hidden_state_size}_epoch={epoch}_gradient_clipping={gradient_clipping}.pt"
            path = os.path.join(results_path, "saved_models", "MNIST_task", task_name, file_name)
            torch.save(model, path)
        min_loss = min(validation_loss, min_loss)

    plot_validation_loss(epochs, gradient_clipping, lr, loss_name,
                         validation_losses, hidden_state_size, task_name)
    if classify:
        plot_validation_acc(epochs, gradient_clipping, lr, loss_name,
                            validation_accuracies, hidden_state_size, task_name)
def train(train_loader, validate_data, device, gradient_clipping=1,
          hidden_state_size=10, lr=0.001, opt="adam", epochs=1000, batch_size=32):
    """Train a toy sequence autoencoder and periodically validate/checkpoint.

    Trains an EncoderDecoder(1, hidden_state_size, 1, 50) to reconstruct its
    input under MSE, saving a checkpoint every 100 epochs and recording a
    validation loss every 20 epochs; finally plots the best reconstruction
    example and the validation-loss curve.

    Args:
        train_loader: iterable of training batches (tensors).
        validate_data: a single validation tensor, moved to `device`.
        device: torch device for model and data.
        gradient_clipping: max grad-norm; falsy disables clipping.
        hidden_state_size: hidden dimension of the EncoderDecoder.
        lr: learning rate.
        opt: "adam" selects Adam, anything else selects RMSprop.
        epochs: number of training epochs.
        batch_size: forwarded to the plotting helpers only.
    """
    model = EncoderDecoder(1, hidden_state_size, 1, 50).to(device)
    validate_data = validate_data.to(device)
    if opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        optimizer_name = "adam"
    else:
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
        # BUG FIX: the original labelled the non-Adam branch 'mse'
        # (a loss name) instead of the actual optimizer.
        optimizer_name = "rmsprop"
    mse = nn.MSELoss()
    min_loss = float("inf")          # best single-batch loss (for the example plot)
    best_loss_global = float("inf")  # best per-epoch mean loss (for checkpoint names)
    min_in, min_out = None, None
    validation_losses = []

    for epoch in range(0, epochs):
        total_loss = 0
        for batch_idx, data in enumerate(train_loader):
            data = data.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = mse(output, data)
            total_loss += loss.item()
            if loss.item() < min_loss:
                min_loss = loss.item()
                min_in, min_out = data, output
            loss.backward()
            if gradient_clipping:
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clipping)
            optimizer.step()

        epoch_loss = total_loss / len(train_loader)
        best_loss_global = min(best_loss_global, epoch_loss)
        print(f'Train Epoch: {epoch} \t loss: {epoch_loss}')

        if epoch % 100 == 0:
            path = f'{results_path}saved_models/ae_toy_{optimizer_name}_lr={lr}_hidden_size={hidden_state_size}_' \
                   f'_gradient_clipping={gradient_clipping}_'
            create_folders(path)
            torch.save(model, path + f"/epoch={epoch}_bestloss={best_loss_global}.pt")

        # Periodic validation.  BUG FIX: run the forward pass under no_grad so
        # it neither builds a graph nor wastes memory; the original also called
        # eval()/train() on the loss module, which is a no-op and was removed.
        if epoch % 20 == 0:
            model.eval()
            with torch.no_grad():
                output = model(validate_data)
                loss = mse(output, validate_data)
            validation_losses.append(loss.item())
            model.train()

    plot_sequence_examples(epochs, gradient_clipping, lr, min_in, min_out,
                           optimizer_name, batch_size)
    plot_validation_loss(epochs, gradient_clipping, lr, optimizer_name,
                         validation_losses, batch_size)
def train(resume_training=True):
    """Train a transformer NMT EncoderDecoder, resuming from checkpoint if present.

    Builds encoder/decoder with fixed hyper-parameters, optionally restores
    model/optimizer state from `PATH_MODEL`, then trains with teacher forcing
    and masked softmax cross-entropy, checkpointing every 10 epochs.

    Args:
        resume_training: when True and a non-empty checkpoint file exists,
            continue from the saved epoch instead of reinitializing weights.
    """
    EMBEDDING_SIZE = 32
    num_hiddens, num_layers, dropout, batch_size, num_steps = EMBEDDING_SIZE, 2, 0.1, 64, 10
    lr, num_epochs, device = 0.005, 1000, d2lt.try_gpu()
    ffn_num_input, ffn_num_hiddens, num_heads = EMBEDDING_SIZE, 64, 4
    key_size, query_size, value_size = EMBEDDING_SIZE, EMBEDDING_SIZE, EMBEDDING_SIZE
    norm_shape = [EMBEDDING_SIZE]

    ### Load data
    data_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size, num_steps)
    encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size,
                                 num_hiddens, norm_shape, ffn_num_input,
                                 ffn_num_hiddens, num_heads, num_layers, dropout)
    decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size,
                                 num_hiddens, norm_shape, ffn_num_input,
                                 ffn_num_hiddens, num_heads, num_layers, dropout)

    ### Load model
    model = EncoderDecoder(encoder, decoder).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    ### Load checkpoint (only if the file exists and is non-empty)
    if resume_training and PATH_MODEL.exists() and os.path.getsize(PATH_MODEL) > 0:
        model, optimizer, last_epoch = load_checkpoint(model, optimizer)
        print("Continue training from last checkpoint...")
    else:
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)
        # Create an empty placeholder file for future checkpoints.
        with open(PATH_MODEL, 'w') as fp:
            pass
        print(
            'No prior checkpoint existed, created new save files for checkpoint.'
        )
        model.apply(xavier_init_weights)
        last_epoch = 0

    ### Initialize loss function
    loss = MaskedSoftmaxCELoss()

    ### Train
    model.train()
    for epoch in range(last_epoch, num_epochs):
        timer = d2lt.Timer()
        metric = d2lt.Accumulator(2)  # sum of training loss, no. of tokens
        for batch in data_iter:
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                               device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # Teacher forcing
            # BUG FIX: the original never cleared gradients, so they
            # accumulated across batches and corrupted every update.
            optimizer.zero_grad()
            Y_hat, _ = model(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # Make the loss scalar for `backward`
            d2lt.grad_clipping(model, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            print(f'epoch {epoch + 1} - '
                  f'loss {metric[0] / metric[1]:.5f}')
            ### Save checkpoint
            save_checkpoint(epoch, model, optimizer)

    print(f'loss {metric[0] / metric[1]:.5f}, {metric[1] / timer.stop():.1f} '
          f'tokens/sec on {str(device)}')