Example #1
    # Assumes `import torch as th` and `from os.path import join`; the
    # enclosing class provides `self.__output_dir` (checkpoint directory)
    # and `self.__curr_save` (running save index).
    def __save_models(self, gen: Generator, disc: Discriminator,
                      optim_gen: th.optim.Adam, optim_disc: th.optim.Adam):
        # Save discriminator weights and optimizer state
        th.save(disc.state_dict(),
                join(self.__output_dir, f"disc_{self.__curr_save}.pt"))

        th.save(optim_disc.state_dict(),
                join(self.__output_dir, f"optim_disc_{self.__curr_save}.pt"))

        # Save generator weights and optimizer state
        th.save(gen.state_dict(),
                join(self.__output_dir, f"gen_{self.__curr_save}.pt"))

        th.save(optim_gen.state_dict(),
                join(self.__output_dir, f"optim_gen_{self.__curr_save}.pt"))
Example #2
# Assumes `import time`, `import torch`, a global options object `opt`
# (with log_interval, start_epoch, epochs, save_model), and module-level
# `loss_func` and `eval` (a project-defined evaluation helper) functions.
def trainModel(model, trainData, validData, optimizer: torch.optim.Adam):
    print(model)
    start_time = time.time()

    def trainEpoch(epoch):
        trainData.shuffle()

        total_loss, total, total_num_correct = 0, 0, 0
        report_loss, report_total, report_num_correct = 0, 0, 0
        for i in range(len(trainData)):
            (batch_docs, batch_docs_len,
             doc_mask), (batch_querys, batch_querys_len,
                         query_mask), batch_answers, candidates = trainData[i]

            model.zero_grad()
            pred_answers, answer_probs = model(batch_docs,
                                               batch_docs_len,
                                               doc_mask,
                                               batch_querys,
                                               batch_querys_len,
                                               query_mask,
                                               answers=batch_answers,
                                               candidates=candidates)

            loss, num_correct = loss_func(batch_answers, pred_answers,
                                          answer_probs)

            loss.backward()
            # Clip gradients element-wise to [-5, 5] before the update;
            # skip parameters that received no gradient this step
            for parameter in model.parameters():
                if parameter.grad is not None:
                    parameter.grad.data.clamp_(-5.0, 5.0)
            # Update the parameters
            optimizer.step()

            total_in_minibatch = batch_answers.size(0)

            report_loss += loss.item() * total_in_minibatch
            report_num_correct += num_correct
            report_total += total_in_minibatch

            total_loss += loss.item() * total_in_minibatch
            total_num_correct += num_correct
            total += total_in_minibatch
            if i % opt.log_interval == 0:
                print(
                    "Epoch %2d, %5d/%5d; avg loss: %.2f; acc: %6.2f;  %6.0f s elapsed"
                    % (epoch, i + 1, len(trainData), report_loss /
                       report_total, report_num_correct / report_total * 100,
                       time.time() - start_time))

                report_loss = report_total = report_num_correct = 0
            # Drop references so autograd buffers can be freed early
            del loss, pred_answers, answer_probs

        return total_loss / total, total_num_correct / total

    for epoch in range(opt.start_epoch, opt.epochs + 1):
        print('')

        #  (1) train for one epoch on the training set
        train_loss, train_acc = trainEpoch(epoch)
        print('Epoch %d:\t average loss: %.2f\t train accuracy: %g' %
              (epoch, train_loss, train_acc * 100))

        #  (2) evaluate on the validation set
        valid_loss, valid_acc = eval(model, validData)
        print('=' * 20)
        print('Evaluating on validation set:')
        print('Validation loss: %.2f' % valid_loss)
        print('Validation accuracy: %g' % (valid_acc * 100))
        print('=' * 20)

        model_state_dict = model.state_dict()
        optimizer_state_dict = optimizer.state_dict()
        #  (3) drop a checkpoint
        checkpoint = {
            'model': model_state_dict,
            'epoch': epoch,
            'optimizer': optimizer_state_dict,
            'opt': opt,
        }
        # Note: the 'models/' directory must exist before saving
        torch.save(
            checkpoint, 'models/%s_epoch%d_acc_%.2f.pt' %
            (opt.save_model, epoch, 100 * valid_acc))
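
To resume training from such a checkpoint, the saved dict can be loaded back with `torch.load`. A minimal sketch, where the file name is a placeholder following the naming pattern above, and `model`, `trainData`, and `validData` are assumed to be rebuilt the same way as before saving:

import torch

checkpoint = torch.load('models/mymodel_epoch3_acc_72.50.pt')  # placeholder name
model.load_state_dict(checkpoint['model'])
optimizer = torch.optim.Adam(model.parameters())
optimizer.load_state_dict(checkpoint['optimizer'])  # restores lr, betas, state
opt = checkpoint['opt']
opt.start_epoch = checkpoint['epoch'] + 1  # continue with the next epoch
trainModel(model, trainData, validData, optimizer)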