def eval(data_loader, model, epoch_train, batch_train, optimizer, args):
    """Evaluate ``model`` on ``data_loader``, print metrics and optionally log them.

    The model is put in eval mode for the pass and switched back to train
    mode afterwards (unconditionally, even if evaluation raises). Gradient
    tracking is disabled while evaluating.

    Args:
        data_loader: iterable yielding ``(inputs, target)`` batches.
        model: module whose output is passed to ``F.nll_loss``; presumably it
            returns log-probabilities of shape (batch, num_classes) -- TODO
            confirm against the model definition.
        epoch_train: current training epoch (int), written to the CSV log.
        batch_train: current training batch index (int), written to the CSV log.
        optimizer: optimizer whose first param group's ``lr`` is reported.
        args: namespace providing ``target_sub_scaler``, ``cuda``,
            ``log_result`` and ``save_folder``.

    Returns:
        ``(avg_loss, accuracy)`` as Python floats: the summed NLL loss divided
        by the number of evaluated samples, and the accuracy in percent.

    Raises:
        ZeroDivisionError: if ``data_loader`` yields no samples.
    """
    model.eval()
    corrects, accumulated_loss, size = 0, 0.0, 0
    predicates_all, target_all = [], []
    try:
        # no_grad replaces the removed ``Variable(volatile=True)`` mechanism:
        # no autograd graph is recorded during evaluation.
        with torch.no_grad():
            for inputs, target in data_loader:
                # Scale target to be zero-indexed. Done out-of-place so the
                # tensors handed out by the loader are not mutated.
                if args.target_sub_scaler != 0:
                    target = target - args.target_sub_scaler

                size += len(target)
                if args.cuda:
                    inputs, target = inputs.cuda(), target.cuda()

                logit = model(inputs)
                predicates = torch.max(logit, 1)[1].view(target.size())
                accumulated_loss += F.nll_loss(logit, target,
                                               reduction='sum').item()
                # .item() yields a plain int, so the printed count and the
                # returned accuracy are numbers rather than device tensors.
                corrects += (predicates == target).sum().item()
                predicates_all += predicates.cpu().tolist()
                target_all += target.cpu().tolist()
    finally:
        model.train()

    avg_loss = accumulated_loss / size
    accuracy = 100.0 * corrects / size
    # Same value as optimizer.state_dict()['param_groups'][0]['lr'], without
    # rebuilding the whole state dict.
    lr = optimizer.param_groups[0]['lr']
    print('\nEvaluation - loss: {:.6f}  lr: {:.5f}  acc: {:.3f}% ({}/{}) '.
          format(avg_loss, lr, accuracy, corrects, size))
    print_f_score(predicates_all, target_all)
    print('\n')
    if args.log_result:
        with open(os.path.join(args.save_folder, 'result.csv'), 'a') as r:
            r.write('\n{:d},{:d},{:.5f},{:.2f},{:f}'.format(
                epoch_train, batch_train, avg_loss, accuracy, lr))

    return avg_loss, accuracy
示例#2
0
        with open(filename, 'wb') as fo:
            pickle.dump(result, fo)

    save_result(result)

    plt.plot(train_acces)
    plt.plot(val_acces)
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='lower right')
    plt.savefig('./mult_acc.png', format='png', dpi=80)
    plt.show()

    plt.plot(train_losses)
    plt.plot(val_losses)
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper right')
    plt.savefig('./mult_loss.png', format='png', dpi=80)
    plt.show()

    model_path = './model/COVID19_multbest.path.tar'
    predicates_all, target_all = test(cnn, model_path, test_loader)
    print_f_score(predicates_all, target_all)
    print(accuracy_score(target_all, predicates_all))
    print(classification_report(target_all, predicates_all))
    print("AUC", roc_auc_score(target_all, predicates_all))
    print(confusion_matrix(target_all, predicates_all))