Code Example #1
File: test.py Project: gohyojun15/CSRNet-pytorch
# imports required by this snippet (module paths are assumed, not part of the excerpt)
import torch
from tqdm import tqdm

from model import CSRNet                       # assumed project-local module
from dataloader import create_test_dataloader  # assumed project-local module


def cal_mae(img_root, model_param_path):
    '''
    Calculate the MAE of the test data.
    img_root: the root of the test image data.
    model_param_path: the path to the trained CSRNet parameters.
    '''
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CSRNet()
    model.load_state_dict(torch.load(model_param_path, map_location=device))
    model.to(device)
    # create_test_dataloader already returns a DataLoader (as in the other
    # examples), so it is used directly rather than wrapped in a second DataLoader.
    dataloader = create_test_dataloader(img_root)
    model.eval()
    mae = 0
    with torch.no_grad():
        for i, data in enumerate(tqdm(dataloader)):
            image = data['image'].to(device)
            gt_densitymap = data['densitymap'].to(device)
            # forward propagation
            et_dmap = model(image)
            mae += abs(et_dmap.sum() - gt_densitymap.sum()).item()
            del image, gt_densitymap, et_dmap
    print("model_param_path:" + model_param_path + " mae:" +
          str(mae / len(dataloader)))
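
For reference, a minimal invocation of this helper might look like the following (both paths are placeholders, not taken from the project):

if __name__ == "__main__":
    cal_mae(img_root="./data/test_images",                # hypothetical test-image root
            model_param_path="./checkpoints/csrnet.pth")  # hypothetical checkpoint path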
Code Example #2
def cal_mae(img_root, gt_dmap_root, model_param_path):
    '''
    Calculate the MAE of the test data.
    img_root: the root of the test image data (unused; cfg.dataset_root is read instead).
    gt_dmap_root: the root of the test ground-truth density-map data (unused).
    model_param_path: the path to the trained CSRNet parameters.
    '''
    model = CSRNet()
    model.load_state_dict(torch.load(model_param_path,
                                     map_location=cfg.device))
    model.to(cfg.device)
    test_dataloader = create_test_dataloader(cfg.dataset_root)  # dataloader
    model.eval()
    sum_mae = 0
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_dataloader)):
            image = data['image'].to(cfg.device)
            gt_densitymap = data['densitymap'].to(cfg.device)
            # forward propagation
            et_densitymap = model(image)  # .detach() is redundant inside torch.no_grad()
            mae = abs(et_densitymap.sum() - gt_densitymap.sum())
            sum_mae += mae.item()
            # clear mem
            del i, data, image, gt_densitymap, et_densitymap
            torch.cuda.empty_cache()

    print("model_param_path:" + model_param_path + " mae:" +
          str(sum_mae / len(test_dataloader)))
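
Both cal_mae variants compute a count-level MAE: each density map is reduced to a scalar count by summation, and the absolute count errors are averaged over the N test images, i.e. MAE = (1/N) * Σ |sum(et_i) - sum(gt_i)|. A standalone sketch of the same computation (pred_maps and gt_maps are hypothetical lists of tensors):

# count-level MAE, independent of any model
errors = [abs(p.sum().item() - g.sum().item()) for p, g in zip(pred_maps, gt_maps)]
mae = sum(errors) / len(errors)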
Code Example #3
def estimate_density_map(img_root, gt_dmap_root, model_param_path, index):
    '''
    Export the estimated density map, the ground-truth density map, and the
    input image for every sample in the test set.
    img_root: the root of the test image data (unused; cfg.dataset_root is read instead).
    gt_dmap_root: the root of the test ground-truth density-map data (unused).
    model_param_path: the path to the trained CSRNet parameters.
    index: the order of the test image in the test dataset (unused; all images are exported).
    '''
    image_export_folder = 'export_images'
    model = CSRNet()
    model.load_state_dict(torch.load(model_param_path,
                                     map_location=cfg.device))
    model.to(cfg.device)
    test_dataloader = create_test_dataloader(cfg.dataset_root)  # dataloader
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_dataloader)):
            image = data['image'].to(cfg.device)
            gt_densitymap = data['densitymap'].to(cfg.device)
            # forward propagation
            et_densitymap = model(image).detach()
            pred_count = et_densitymap.data.sum().cpu()
            actual_count = gt_densitymap.data.sum().cpu()
            et_densitymap = et_densitymap.squeeze(0).squeeze(0).cpu().numpy()
            gt_densitymap = gt_densitymap.squeeze(0).squeeze(0).cpu().numpy()
            image = image[0].cpu()  # denormalize(image[0].cpu())
            print(et_densitymap.shape)  # debug: shape of the estimated density map
            # et is the estimated density
            plt.imshow(et_densitymap, cmap=CM.jet)
            plt.savefig("{}/{}_{}_{}_{}".format(image_export_folder,
                                                str(i).zfill(3),
                                                str(int(pred_count)),
                                                str(int(actual_count)),
                                                'etdm.png'))
            # gt is the ground truth density
            plt.imshow(gt_densitymap, cmap=CM.jet)
            plt.savefig("{}/{}_{}_{}_{}".format(image_export_folder,
                                                str(i).zfill(3),
                                                str(int(pred_count)),
                                                str(int(actual_count)),
                                                'gtdm.png'))
            # image
            plt.imshow(image.permute(1, 2, 0))
            plt.savefig("{}/{}_{}_{}_{}".format(image_export_folder,
                                                str(i).zfill(3),
                                                str(int(pred_count)),
                                                str(int(actual_count)),
                                                'image.png'))

            # clear mem
            del i, data, image, et_densitymap, gt_densitymap, pred_count, actual_count
            torch.cuda.empty_cache()
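
Repeated plt.imshow / plt.savefig calls above reuse one implicit figure, which can leak matplotlib state across iterations. A generic pattern that isolates each save (a sketch, not taken from the project):

import matplotlib.pyplot as plt

def save_map(array, path, cmap=None):
    # create one dedicated figure per image; closing it releases the memory
    fig = plt.figure()
    plt.imshow(array, cmap=cmap)
    plt.savefig(path)
    plt.close(fig)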
Code Example #4
--test_image_root /home/gohyojun/바탕화면/Anthroprocene/Dataset/crane
--test_image_gt_root /home/gohyojun/바탕화면/Anthroprocene/Dataset/crane_labeled
--test_image_density_root /home/gohyojun/바탕화면/Anthroprocene/Dataset/density_map
"""


if __name__ == "__main__":
    # argument parsing.
    args = parser.parse_args()
    cfg = Config(args)                                                      # configuration
    model = CSRNet().to(cfg.device)                                         # model
    criterion = nn.MSELoss(reduction='sum')                                 # objective (size_average=False is deprecated)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)             # optimizer

    train_dataloader = create_train_dataloader(cfg.train_dataset_root, use_flip=True, batch_size=cfg.batch_size)
    test_dataloader  = create_test_dataloader(cfg.test_dataset_root)             # dataloader

    min_mae = sys.maxsize
    min_mae_epoch = -1
    for epoch in range(1, cfg.epochs):                      # start training (note: runs cfg.epochs - 1 epochs)
        model.train()
        epoch_loss = 0.0
        for i, data in enumerate(tqdm(train_dataloader)):
            image = data['image'].to(cfg.device)
            gt_densitymap = data['densitymap'].to(cfg.device) * 16  # TODO: x16 because of the 1/4 rescale effect (4 x 4 = 16 preserves the count sum)
            et_densitymap = model(image)                        # forward propagation
            loss = criterion(et_densitymap, gt_densitymap)      # calculate loss
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()                                     # back propagation
            optimizer.step()                                    # update network parameters
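
The excerpt ends inside the epoch loop. A plausible continuation (an assumed sketch, not the project's actual code) evaluates the count-level MAE on the test loader after each epoch and tracks the best checkpoint via the min_mae variables initialised above:

        model.eval()                                        # per-epoch evaluation (assumed)
        epoch_mae = 0.0
        with torch.no_grad():
            for data in test_dataloader:
                image = data['image'].to(cfg.device)
                gt_densitymap = data['densitymap'].to(cfg.device)
                epoch_mae += abs(model(image).sum() - gt_densitymap.sum()).item()
        epoch_mae /= len(test_dataloader)
        if epoch_mae < min_mae:
            min_mae, min_mae_epoch = epoch_mae, epoch
            torch.save(model.state_dict(), 'best_model.pth')  # hypothetical filename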
Code Example #5
def main():
    # ========= dataloaders ===========
    train_dataloader = create_train_dataloader(root=args.datapath,
                                               batch_size=args.batch_size)
    test_dataloader = create_val_dataloader(root=args.datapath,
                                            batch_size=args.batch_size)
    # train_dataloader, test_dataloader = create_CK_dataloader(batch_size=args.batch_size)
    start_epoch = 0
    # ======== models & loss ==========
    mini_xception = Mini_Xception()
    loss = nn.CrossEntropyLoss()
    # ========= load weights ===========
    if args.resume or args.evaluate:
        checkpoint = torch.load(args.pretrained, map_location=device)
        mini_xception.load_state_dict(checkpoint['mini_xception'],
                                      strict=False)
        start_epoch = checkpoint['epoch'] + 1
        print(f'\tLoaded checkpoint from {args.pretrained}\n')
        time.sleep(1)
    else:
        print(
            "******************* Start training from scratch *******************\n"
        )
        time.sleep(2)

    if args.evaluate:
        if args.mode == 'test':
            test_dataloader = create_test_dataloader(
                root=args.datapath, batch_size=args.batch_size)
        elif args.mode == 'val':
            test_dataloader = create_val_dataloader(root=args.datapath,
                                                    batch_size=args.batch_size)
        else:
            test_dataloader = create_train_dataloader(
                root=args.datapath, batch_size=args.batch_size)

        validate(mini_xception, loss, test_dataloader, 0)
        return

    # =========== optimizer ===========
    # parameters = mini_xception.named_parameters()
    # for name, p in parameters:
    #     print(p.requires_grad, name)
    # return
    optimizer = torch.optim.Adam(mini_xception.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=args.lr_patience, verbose=True)
    # ========================================================================
    for epoch in range(start_epoch, args.epochs):
        # =========== train / validate ===========
        train_loss = train_one_epoch(mini_xception, loss, optimizer,
                                     train_dataloader, epoch)
        val_loss, accuracy, precision, recall = validate(
            mini_xception, loss, test_dataloader, epoch)
        scheduler.step(val_loss)
        val_loss, accuracy, precision, recall = round(val_loss, 3), round(
            accuracy, 3), round(precision, 3), round(recall, 3)
        logging.info(f"\ttraining epoch={epoch} .. train_loss={train_loss}")
        logging.info(f"\tvalidation epoch={epoch} .. val_loss={val_loss}")
        logging.info(
            f'\tAccuracy = {accuracy*100} % .. Precision = {precision*100} % .. Recall = {recall*100} % \n'
        )
        time.sleep(2)
        # ============= tensorboard =============
        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('val_loss', val_loss, epoch)
        writer.add_scalar('precision', precision, epoch)
        writer.add_scalar('recall', recall, epoch)
        writer.add_scalar('accuracy', accuracy, epoch)
        # ============== save model =============
        if epoch % args.savefreq == 0:
            checkpoint_state = {
                'mini_xception': mini_xception.state_dict(),
                "epoch": epoch
            }
            savepath = os.path.join(args.savepath,
                                    f'weights_epoch_{epoch}.pth.tar')
            torch.save(checkpoint_state, savepath)
            print(f'\n\t*** Saved checkpoint in {savepath} ***\n')
            time.sleep(2)
    writer.close()
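
Note on the scheduler: ReduceLROnPlateau in mode='min' multiplies the learning rate by its factor (0.1 by default) once the monitored value has failed to improve for more than patience epochs, which is why scheduler.step(val_loss) is called once per epoch. A standalone illustration with toy values:

import torch

# toy demonstration of ReduceLROnPlateau (values chosen for illustration only)
param = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([param], lr=1e-3)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', patience=2)
for val in [1.0, 0.9, 0.9, 0.9, 0.9]:
    sched.step(val)               # no improvement after the 2nd step; lr drops on the last one
print(opt.param_groups[0]['lr'])  # 0.0001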