def __save_models(self, gen: Generator, disc: Discriminator, optim_gen: th.optim.Adam, optim_disc: th.optim.Adam):
    """Checkpoint the GAN: write the state dicts of both networks and their
    optimizers to ``self.__output_dir``, suffixed with the current save index.

    Produces four files: disc_<k>.pt, optim_disc_<k>.pt, gen_<k>.pt,
    optim_gen_<k>.pt, where k is ``self.__curr_save``.
    """
    # (file prefix, stateful object) pairs — one .pt file per entry,
    # written in the same order as before (discriminator first).
    targets = [
        ("disc", disc),
        ("optim_disc", optim_disc),
        ("gen", gen),
        ("optim_gen", optim_gen),
    ]
    for prefix, obj in targets:
        th.save(obj.state_dict(), join(self.__output_dir, f"{prefix}_{self.__curr_save}.pt"))
def trainModel(model, trainData, validData, optimizer: torch.optim.Adam):
    """Train ``model`` for ``opt.epochs`` epochs, evaluating on the validation
    set and dropping a checkpoint under models/ after every epoch.

    Args:
        model:     the QA model; called with doc/query batches and returns
                   (pred_answers, answer_probs).
        trainData: indexable dataset with ``shuffle()`` and ``len()``;
                   trainData[i] yields ((docs, docs_len, doc_mask),
                   (querys, querys_len, query_mask), answers, candidates).
        validData: validation dataset passed to the module-level ``eval``.
        optimizer: Adam optimizer already bound to model.parameters().
    """
    print(model)
    start_time = time.time()

    def trainEpoch(epoch):
        """Run one pass over trainData; return (avg loss, accuracy)."""
        trainData.shuffle()
        total_loss, total, total_num_correct = 0, 0, 0
        report_loss, report_total, report_num_correct = 0, 0, 0
        for i in range(len(trainData)):
            (batch_docs, batch_docs_len, doc_mask), (batch_querys, batch_querys_len, query_mask), batch_answers, candidates = trainData[i]
            model.zero_grad()
            pred_answers, answer_probs = model(batch_docs, batch_docs_len, doc_mask,
                                               batch_querys, batch_querys_len, query_mask,
                                               answers=batch_answers, candidates=candidates)
            loss, num_correct = loss_func(batch_answers, pred_answers, answer_probs)
            loss.backward()
            # Clip gradients to [-5, 5]. clip_grad_value_ skips parameters whose
            # grad is None — the original manual clamp raised AttributeError for
            # frozen/unused parameters.
            torch.nn.utils.clip_grad_value_(model.parameters(), 5.0)
            # update the parameters
            optimizer.step()

            total_in_minibatch = batch_answers.size(0)
            # loss.item() replaces the long-removed `loss.data[0]` indexing
            # (0-dim tensors cannot be indexed in PyTorch >= 0.5).
            batch_loss = loss.item() * total_in_minibatch
            report_loss += batch_loss
            report_num_correct += num_correct
            report_total += total_in_minibatch
            total_loss += batch_loss
            total_num_correct += num_correct
            total += total_in_minibatch
            if i % opt.log_interval == 0:
                print(
                    "Epoch %2d, %5d/%5d; avg loss: %.2f; acc: %6.2f; %6.0f s elapsed"
                    % (epoch, i + 1, len(trainData),
                       report_loss / report_total,
                       report_num_correct / report_total * 100,
                       time.time() - start_time))
                report_loss = report_total = report_num_correct = 0
            # release the graph before the next batch to keep memory flat
            del loss, pred_answers, answer_probs
        return total_loss / total, total_num_correct / total

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')
        # (1) train for one epoch on the training set
        train_loss, train_acc = trainEpoch(epoch)
        print('Epoch %d:\t average loss: %.2f\t train accuracy: %g'
              % (epoch, train_loss, train_acc * 100))

        # (2) evaluate on the validation set
        # NOTE(review): `eval` here is a project-level evaluation function that
        # shadows the builtin — presumably defined elsewhere in this file.
        valid_loss, valid_acc = eval(model, validData)
        print('=' * 20)
        print('Evaluating on validation set:')
        print('Validation loss: %.2f' % valid_loss)
        print('Validation accuracy: %g' % (valid_acc * 100))
        print('=' * 20)

        model_state_dict = model.state_dict()
        optimizer_state_dict = optimizer.state_dict()
        # (4) drop a checkpoint
        checkpoint = {
            'model': model_state_dict,
            'epoch': epoch,
            'optimizer': optimizer_state_dict,
            'opt': opt,
        }
        torch.save(
            checkpoint,
            'models/%s_epoch%d_acc_%.2f.pt'
            % (opt.save_model, epoch, 100 * valid_acc))