def main():
    args = parse_args()

    update_config(cfg, args)
    cfg.defrost()
    cfg.freeze()

    record_prefix = './eval2D_results_'
    if args.is_vis:
        result_dir = record_prefix + cfg.EXP_NAME
        mse2d_lst = np.loadtxt(os.path.join(result_dir,
                                            'mse2d_each_joint.txt'))
        PCK2d_lst = np.loadtxt(os.path.join(result_dir, 'PCK2d.txt'))

        plot_performance(PCK2d_lst[1, :], PCK2d_lst[0, :], mse2d_lst)
        exit()

    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model_path = args.model_path
    is_vis = args.is_vis

    # FP16 SETTING
    if cfg.FP16.ENABLED:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if cfg.FP16.STATIC_LOSS_SCALE != 1.0:
        if not cfg.FP16.ENABLED:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    model = eval(cfg.MODEL.NAME + '.get_pose_net')(cfg, is_train=False)

    # # calculate GFLOPS
    # dump_input = torch.rand(
    #     (5, 3, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[0])
    # )

    # print(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    # ops, params = get_model_complexity_info(
    #    model, (3, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[0]),
    #    as_strings=True, print_per_layer_stat=True, verbose=True)
    # input()

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.MODEL.SYNC_BN and not args.distributed:
        print(
            'Warning: Sync BatchNorm is only supported in distributed training.'
        )

    if args.gpu != -1:
        device = torch.device('cuda:' + str(args.gpu))
        torch.cuda.set_device(args.gpu)
    else:
        device = torch.device('cpu')
    # load model state
    if model_path:
        print("Loading model:", model_path)
        ckpt = torch.load(model_path)  #, map_location='cpu')
        if 'state_dict' not in ckpt.keys():
            state_dict = ckpt
        else:
            state_dict = ckpt['state_dict']
            print('Model epoch {}'.format(ckpt['epoch']))

        for key in list(state_dict.keys()):
            new_key = key.replace("module.", "")
            state_dict[new_key] = state_dict.pop(key)

        model.load_state_dict(state_dict, strict=True)

    model.to(device)

    # calculate GFLOPS
    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[0])).to(device)

    print(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    model.eval()

    # inference_dataset = eval('dataset.{}'.format(cfg.DATASET.TEST_DATASET[0].replace('_kpt','')))(
    #     cfg.DATA_DIR,
    #     cfg.DATASET.TEST_SET,
    #     transform=transform
    # )
    inference_dataset = eval('dataset.{}'.format(
        cfg.DATASET.TEST_DATASET[0].replace('_kpt', '')))(
            cfg.DATA_DIR,
            cfg.DATASET.TEST_SET,
            transforms=build_transforms(cfg, is_train=False))

    batch_size = args.batch_size
    data_loader = torch.utils.data.DataLoader(
        inference_dataset,
        batch_size=batch_size,  #48
        shuffle=False,
        num_workers=min(8, batch_size),  #8
        pin_memory=False)

    print('\nEvaluation loader information:\n' + str(data_loader.dataset))
    n_joints = cfg.DATASET.NUM_JOINTS
    th2d_lst = np.array([i for i in range(1, 50)])
    PCK2d_lst = np.zeros((len(th2d_lst), ))
    mse2d_lst = np.zeros((n_joints, ))
    visibility_lst = np.zeros((n_joints, ))

    print('Start evaluating... [Batch size: {}]\n'.format(
        data_loader.batch_size))
    with torch.no_grad():
        pose2d_mse_loss = JointsMSELoss().to(device)
        infer_time = [0, 0]
        start_time = time.time()
        for i, ret in enumerate(data_loader):
            # pose2d_gt: b x 21 x 2 is [u,v] 0<=u<64, 0<=v<64 (heatmap size)
            # visibility: b x 21 vis=0/1
            imgs = ret['imgs']
            pose2d_gt = ret['pose2d']  # b [x v] x 21 x 2
            visibility = ret['visibility']  # b [x v] x 21 x 1

            s1 = time.time()
            if 'CPM' == cfg.MODEL.NAME:
                pose2d_gt = pose2d_gt.view(-1, *pose2d_gt.shape[-2:])
                heatmap_lst = model(
                    imgs.to(device), ret['centermaps'].to(device)
                )  # 6 groups of heatmaps, each of which has size (1,22,32,32)
                heatmaps = heatmap_lst[-1][:, 1:]
                pose2d_pred = data_loader.dataset.get_kpts(heatmaps)
                hm_size = heatmap_lst[-1].shape[-1]  # 32
            else:
                if cfg.MODEL.NAME == 'pose_hrnet_transformer':
                    # imgs: b(1) x (4*seq_len) x 3 x 256 x 256
                    n_batches, seq_len = imgs.shape[0], imgs.shape[1] // 4
                    idx_lst = torch.tensor([4 * i for i in range(seq_len)])
                    imgs = torch.stack([
                        imgs[b, idx_lst + cam_idx] for b in range(n_batches)
                        for cam_idx in range(4)
                    ])  # (b*4) x seq_len x 3 x 256 x 256

                    pose2d_pred, heatmaps_pred, _ = model(
                        imgs.cuda(device))  # (b*4) x 21 x 2
                    pose2d_gt = pose2d_gt[:, 4 * (seq_len // 2):4 * (
                        seq_len // 2 + 1)].contiguous().view(
                            -1, *pose2d_pred.shape[-2:])  # (b*4) x 21 x 2
                    visibility = visibility[:, 4 * (seq_len // 2):4 * (
                        seq_len // 2 + 1)].contiguous().view(
                            -1, *visibility.shape[-2:])  # (b*4) x 21

                else:
                    if 'Aggr' in cfg.MODEL.NAME:
                        # imgs: b x (4*5) x 3 x 256 x 256
                        n_batches, seq_len = imgs.shape[0], len(
                            cfg.DATASET.SEQ_IDX)
                        true_batch_size = imgs.shape[1] // seq_len
                        pose2d_gt = torch.cat([
                            pose2d_gt[b, true_batch_size *
                                      (seq_len // 2):true_batch_size *
                                      (seq_len // 2 + 1)]
                            for b in range(n_batches)
                        ],
                                              dim=0)

                        visibility = torch.cat([
                            visibility[b, true_batch_size *
                                       (seq_len // 2):true_batch_size *
                                       (seq_len // 2 + 1)]
                            for b in range(n_batches)
                        ],
                                               dim=0)

                        imgs = torch.cat([
                            imgs[b, true_batch_size * j:true_batch_size *
                                 (j + 1)] for j in range(seq_len)
                            for b in range(n_batches)
                        ],
                                         dim=0)  # (b*4*5) x 3 x 256 x 256

                        heatmaps_pred, _ = model(imgs.to(device))
                    else:
                        pose2d_gt = pose2d_gt.view(-1, *pose2d_gt.shape[-2:])
                        heatmaps_pred, _ = model(
                            imgs.to(device))  # b x 21 x 64 x 64

                    pose2d_pred = get_final_preds(
                        heatmaps_pred, cfg.MODEL.HEATMAP_SOFTMAX)  # b x 21 x 2

                hm_size = heatmaps_pred.shape[-1]  # 64

            if i > 20:
                infer_time[0] += 1
                infer_time[1] += time.time() - s1

            # rescale to the original image before DLT

            if 'RHD' in cfg.DATASET.TEST_DATASET[0]:
                crop_size, corner = ret['crop_size'], ret['corner']
                crop_size, corner = crop_size.view(-1, 1, 1), corner.unsqueeze(
                    1)  # b x 1 x 1; b x 2 x 1
                pose2d_pred = pose2d_pred.cpu() * crop_size / hm_size + corner
                pose2d_gt = pose2d_gt * crop_size / hm_size + corner
            else:
                orig_width, orig_height = data_loader.dataset.orig_img_size
                pose2d_pred[:, :, 0] *= orig_width / hm_size
                pose2d_pred[:, :, 1] *= orig_height / hm_size
                pose2d_gt[:, :, 0] *= orig_width / hm_size
                pose2d_gt[:, :, 1] *= orig_height / hm_size

                # for k in range(21):
                #     print(pose2d_gt[0,k].tolist(), pose2d_pred[0,k].tolist())
                # input()
            # 2D errors
            pose2d_pred, pose2d_gt, visibility = pose2d_pred.cpu().numpy(
            ), pose2d_gt.numpy(), visibility.squeeze(2).numpy()

            # import matplotlib.pyplot as plt
            # imgs = cv2.resize(imgs[0].permute(1,2,0).cpu().numpy(), tuple(data_loader.dataset.orig_img_size))
            # for k in range(21):
            #     print(pose2d_gt[0,k],pose2d_pred[0,k],visibility[0,k])
            # for k in range(0,21,5):
            #     fig = plt.figure()
            #     ax1 = fig.add_subplot(131)
            #     ax2 = fig.add_subplot(132)
            #     ax3 = fig.add_subplot(133)
            #     ax1.imshow(cv2.cvtColor(imgs / imgs.max(), cv2.COLOR_BGR2RGB))
            #     plot_hand(ax1, pose2d_gt[0,:,0:2], order='uv')
            #     ax2.imshow(cv2.cvtColor(imgs / imgs.max(), cv2.COLOR_BGR2RGB))
            #     plot_hand(ax2, pose2d_pred[0,:,0:2], order='uv')
            #     ax3.imshow(heatmaps_pred[0,k].cpu().numpy())
            #     plt.show()
            mse_each_joint = np.linalg.norm(pose2d_pred - pose2d_gt,
                                            axis=2) * visibility  # b x 21

            mse2d_lst += mse_each_joint.sum(axis=0)
            visibility_lst += visibility.sum(axis=0)

            for th_idx in range(len(th2d_lst)):
                PCK2d_lst[th_idx] += np.sum(
                    (mse_each_joint < th2d_lst[th_idx]) * visibility)

            period = 10
            if i % (len(data_loader) // period) == 0:
                print("[Evaluation]{}% finished.".format(
                    period * i // (len(data_loader) // period)))
            #if i == 10:break
        print('Evaluation spent {:.2f} s\tfps: {:.1f} {:.4f}'.format(
            time.time() - start_time, infer_time[0] / infer_time[1],
            infer_time[1] / infer_time[0]))

        mse2d_lst /= visibility_lst
        PCK2d_lst /= visibility_lst.sum()

        result_dir = record_prefix + cfg.EXP_NAME
        if not os.path.exists(result_dir):
            os.mkdir(result_dir)

        mse_file, pck_file = os.path.join(
            result_dir,
            'mse2d_each_joint.txt'), os.path.join(result_dir, 'PCK2d.txt')
        print('Saving results to ' + mse_file)
        print('Saving results to ' + pck_file)
        np.savetxt(mse_file, mse2d_lst, fmt='%.4f')
        np.savetxt(pck_file, np.stack((th2d_lst, PCK2d_lst)))

        plot_performance(PCK2d_lst, th2d_lst, mse2d_lst)
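
# A minimal, self-contained sketch of the PCK/MSE bookkeeping performed in the
# evaluation loop above, assuming predictions and ground truth of shape
# (N, 21, 2) in pixels and an (N, 21) visibility mask with 0/1 entries. The
# function name and arguments are illustrative, not part of the original code.
import numpy as np

def compute_mse_and_pck(pose2d_pred, pose2d_gt, visibility, thresholds):
    # Per-joint Euclidean error, masked so invisible joints contribute nothing.
    err = np.linalg.norm(pose2d_pred - pose2d_gt, axis=2) * visibility  # N x 21
    mse_per_joint = err.sum(axis=0) / np.maximum(visibility.sum(axis=0), 1)
    # PCK@t: fraction of visible joints whose error falls below each threshold.
    n_visible = max(visibility.sum(), 1)
    pck = np.array([((err < t) * visibility).sum() / n_visible for t in thresholds])
    return mse_per_joint, pck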

# Example 2
def main():
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=False)

    dump_input = torch.rand(
        (1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE))
    logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
    else:
        model_state_file = os.path.join(final_output_dir, 'model_best.pth.tar')
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    test_dataset = HIEDataset(DATA_PATH)
    data_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=0,
                                              pin_memory=False)

    if cfg.MODEL.NAME == 'pose_hourglass':
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])
    else:
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                # mean=[0.485, 0.456, 0.406],
                # std=[0.229, 0.224, 0.225]
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5])
        ])

    parser = HeatmapParser(cfg)
    all_preds = []
    all_scores = []

    pbar = tqdm(total=len(test_dataset)) if cfg.TEST.LOG_PROGRESS else None
    for i, images in enumerate(data_loader):
        # for i, (images, annos) in enumerate(data_loader):
        assert 1 == images.size(0), 'Test batch size should be 1'

        image = images[0].cpu().numpy()
        # size at scale 1.0
        if i % 100 == 0:
            print("Start processing image %d" % i)
        base_size, center, scale = get_multi_scale_size(
            image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR))
        # print("Multi-scale end")

        with torch.no_grad():
            final_heatmaps = None
            tags_list = []
            for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR,
                                           reverse=True)):
                input_size = cfg.DATASET.INPUT_SIZE
                image_resized, center, scale = resize_align_multi_scale(
                    image, input_size, s, min(cfg.TEST.SCALE_FACTOR))
                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                outputs, heatmaps, tags = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                    cfg.TEST.PROJECT2IMAGE, base_size)

                final_heatmaps, tags_list = aggregate_results(
                    cfg, s, final_heatmaps, tags_list, heatmaps, tags)

            final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
            tags = torch.cat(tags_list, dim=4)
            grouped, scores = parser.parse(final_heatmaps, tags,
                                           cfg.TEST.ADJUST, cfg.TEST.REFINE)

            final_results = get_final_preds(
                grouped, center, scale,
                [final_heatmaps.size(3),
                 final_heatmaps.size(2)])

        if cfg.TEST.LOG_PROGRESS:
            pbar.update()

        if i % cfg.PRINT_FREQ == 0:
            prefix = '{}_{}'.format(
                os.path.join(final_output_dir, 'result_valid'), i)
            # logger.info('=> write {}'.format(prefix))
            # save_valid_image(image, final_results, '{}.jpg'.format(prefix),         dataset=test_dataset.name)
            # save_valid_image(image, final_results, '{}.jpg'.format(prefix),dataset='HIE20')
            # save_debug_images(cfg, image_resized, None, None, outputs, prefix)

        all_preds.append(final_results)
        all_scores.append(scores)

    if cfg.TEST.LOG_PROGRESS:
        pbar.close()

    # save preds and scores as json
    test_dataset.save_json(all_preds, all_scores)
    print('Save finished!')
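
# A rough sketch of the per-scale aggregation pattern used above, assuming each
# scale yields heatmaps of shape (1, K, H, W) and associative-embedding tags of
# shape (1, K, H, W, 1), already resized to the common base size. This only
# illustrates the idea, not the actual aggregate_results() implementation.
import torch

def aggregate_scales(per_scale_heatmaps, per_scale_tags):
    # Average the heatmaps over scales; keep every scale's tag map so the
    # grouping step can see all of them along the last dimension.
    final_heatmaps = torch.stack(per_scale_heatmaps, dim=0).mean(dim=0)
    tags = torch.cat(per_scale_tags, dim=4)  # (1, K, H, W, n_scales)
    return final_heatmaps, tags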

# Example 3
def main_worker(
        gpu, ngpus_per_node, args, final_output_dir, tb_log_dir
):
    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    if cfg.FP16.ENABLED:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if cfg.FP16.STATIC_LOSS_SCALE != 1.0:
        if not cfg.FP16.ENABLED:
            print("Warning:  if --fp16 is not used, static_loss_scale will be ignored.")

    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if cfg.MULTIPROCESSING_DISTRIBUTED:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        print('Init process group: dist_url: {}, world_size: {}, rank: {}'.
              format(args.dist_url, args.world_size, args.rank))
        dist.init_process_group(
            backend=cfg.DIST_BACKEND,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank
        )

    update_config(cfg, args)

    # setup logger
    logger, _ = setup_logger(final_output_dir, args.rank, 'train')

    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=True
    )

    # copy model file
    if not cfg.MULTIPROCESSING_DISTRIBUTED or (
            cfg.MULTIPROCESSING_DISTRIBUTED
            and args.rank % ngpus_per_node == 0
    ):
        this_dir = os.path.dirname(__file__)
        shutil.copy2(
            os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
            final_output_dir
        )

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    if not cfg.MULTIPROCESSING_DISTRIBUTED or (
            cfg.MULTIPROCESSING_DISTRIBUTED
            and args.rank % ngpus_per_node == 0
    ):
        dump_input = torch.rand(
            (1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE)
        )
        writer_dict['writer'].add_graph(model, (dump_input, ))
        # logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            # args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu]
            )
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    loss_factory = MultiLossFactory(cfg).cuda()

    # Data loading code
    train_loader = make_dataloader(
        cfg, is_train=True, distributed=args.distributed
    )
    logger.info(train_loader.dataset)

    best_perf = -1
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)

    if cfg.FP16.ENABLED:
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=cfg.FP16.STATIC_LOSS_SCALE,
            dynamic_loss_scale=cfg.FP16.DYNAMIC_LOSS_SCALE
        )

    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(
        final_output_dir, 'checkpoint.pth.tar')
    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    if cfg.FP16.ENABLED:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer.optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR,
            last_epoch=last_epoch
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR,
            last_epoch=last_epoch
        )

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # train one epoch
        do_train(cfg, model, train_loader, loss_factory, optimizer, epoch,
                 final_output_dir, tb_log_dir, writer_dict, fp16=cfg.FP16.ENABLED)

        perf_indicator = epoch
        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        if not cfg.MULTIPROCESSING_DISTRIBUTED or (
                cfg.MULTIPROCESSING_DISTRIBUTED
                and args.rank == 0
        ):
            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            save_checkpoint({
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(
        final_output_dir, 'final_state{}.pth.tar'.format(gpu)
    )

    logger.info('saving final model state to {}'.format(
        final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
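
# A hedged sketch of how a per-GPU worker like main_worker() above is typically
# launched when cfg.MULTIPROCESSING_DISTRIBUTED is enabled. The launch() helper
# and its arguments are illustrative; only main_worker() comes from the code above.
import torch
import torch.multiprocessing as mp

def launch(args, final_output_dir, tb_log_dir):
    ngpus_per_node = torch.cuda.device_count()
    if args.distributed:
        # One process per GPU; mp.spawn passes the local gpu index as the first argument.
        args.world_size = ngpus_per_node * args.world_size
        mp.spawn(main_worker,
                 nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args, final_output_dir, tb_log_dir))
    else:
        main_worker(args.gpu, ngpus_per_node, args, final_output_dir, tb_log_dir)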

# Example 4
def main():
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)
    # pose_dir = prepare_output_dirs(args.outputDir)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=False)

    dump_input = torch.rand(
        (1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE))
    logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
    else:
        model_state_file = os.path.join(final_output_dir, 'model_best.pth.tar')
        logger.info('=> loading model from {}'.format(model_state_file))
        pretrained_model_state = torch.load(model_state_file)

        for name, param in model.state_dict().items():

            model.state_dict()[name].copy_(pretrained_model_state['1.' + name])
        # model.load_state_dict(torch.load(model_state_file))

    # model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    # Input to the model
    # batch_size = 1
    x = torch.randn(1, 3, 256, 256, requires_grad=True)
    torch_out = model(x)

    # Export the model
    torch.onnx.export(
        model,                     # model being run
        x,                         # model input (or a tuple for multiple inputs)
        args.output_onnx,          # where to save the model (can be a file or file-like object)
        export_params=True,        # store the trained parameter weights inside the model file
        opset_version=11,          # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=['input'],     # the model's input names
        output_names=['output1'])  # the model's output names
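
# A hedged sketch of checking the exported ONNX graph against the PyTorch output,
# assuming onnxruntime is installed and the model produces a single output tensor.
# 'input' and 'output1' mirror the names passed to torch.onnx.export above; the
# verify_onnx() helper itself is illustrative and not part of the original script.
import numpy as np
import onnxruntime

def verify_onnx(onnx_path, x, torch_out, rtol=1e-3, atol=1e-4):
    session = onnxruntime.InferenceSession(onnx_path)
    ort_out = session.run(None, {'input': x.detach().cpu().numpy()})[0]
    np.testing.assert_allclose(torch_out.detach().cpu().numpy(), ort_out,
                               rtol=rtol, atol=atol)
    print('Exported ONNX model matches the PyTorch output.')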

def main():
    parser = argparse.ArgumentParser(
        description='Please specify the mode [training/assessment/predicting]')
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    parser.add_argument('--gpu',
                        help='gpu id for multiprocessing training',
                        default=-1,
                        type=int)
    parser.add_argument('--world-size',
                        default=1,
                        type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--image_path', type=str)
    args = parser.parse_args()

    model_path = args.model_path

    update_config(cfg, args)

    ngpus_per_node = torch.cuda.device_count()

    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    if cfg.FP16.ENABLED:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if cfg.FP16.STATIC_LOSS_SCALE != 1.0:
        if not cfg.FP16.ENABLED:
            print("Warning:  if --fp16 is not used, static_loss_scale will be ignored.")

    device = 'cpu' if args.gpu == -1 else 'cuda:{}'.format(args.gpu)
    print("Use {}".format(device))

    # model initialization
    model = eval(cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=True
    )
    # load model state
    print("Loading model:", model_path)
    checkpoint = torch.load(model_path, map_location='cpu')

    if 'state_dict' not in checkpoint.keys():
        for key in list(checkpoint.keys()):
            new_key = key.replace("module.", "")
            checkpoint[new_key] = checkpoint.pop(key)
        model.load_state_dict(checkpoint, strict=False)
    else:
        state_dict = checkpoint['state_dict']
        for key in list(state_dict.keys()):
            new_key = key.replace("module.", "")
            state_dict[new_key] = state_dict.pop(key)
        model.load_state_dict(state_dict, strict=False)
        print('Model epoch {}'.format(checkpoint['epoch']))

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    #model = torch.nn.DataParallel(model)
    model.to(device)


    legend_lst = np.array([
    # 0           1               2                  3                4
    'wrist', 'thumb palm', 'thumb near palm', 'thumb near tip', 'thumb tip',
    # 5                    6                 7                8
    'index palm', 'index near palm', 'index near tip', 'index tip',
    # 9                    10                  11               12
    'middle palm', 'middle near palm', 'middle near tip', 'middle tip',
    # 13                  14               15            16
    'ring palm', 'ring near palm', 'ring near tip', 'ring tip',
    # 17                  18               19              20
    'pinky palm', 'pinky near palm', 'pinky near tip', 'pinky tip'])

    legend_dict = OrderedDict(sorted(zip(legend_lst, [i for i in range(21)]), key=lambda x:x[1]))
    
    """
    python inference.py --cfg ../experiments/RHD/RHD_w32_256x256_adam_lr1e-3.yaml --model_path  ../output/RHD_kpt/pose_hrnet/w32_256x256_adam_lr1e-3/model_best.pth.tar --image_path ../test_images/video_rgb.mp4
    python inference.py --cfg ../experiments/MHP/MHP_v1.yaml --model ../output/MHP/MHP_trainable_softmax_v2/MHP_kpt/pose_hrnet_trainable_softmax/MHP_v1/model_best.pth.tar --image_path ../test_images/
    python inference.py --cfg ../experiments/JointTraining/JointTraining_v1.yaml --model_path ../output/JointTraining/JointTraining_v1/model_best.pth.tar --image_path ../test_images/hand.mp4 --gpu cpu
    """
    def predict_one_img(image, show=False, img_path=None):
        trans = build_transforms(cfg, is_train=False)
        temp_joints = [np.ones((21,3))]
        orig_img = image.copy()
        resized_image = cv2.cvtColor(cv2.resize(image, tuple(cfg.MODEL.IMAGE_SIZE)), cv2.COLOR_RGB2BGR)
        I, _ = trans(resized_image, temp_joints)
        I = I.unsqueeze(0).to(device)  # .to('cpu') is a no-op when running on CPU

        model.eval()
        with torch.no_grad():
            start_time = time.time()
            output, _ = model(I) # output size: 1 x 21 x 64(H) x 64(W)
            print('Inference time: {:.4f} s'.format(time.time()-start_time))
            kps_pred_np = get_final_preds(output, use_softmax=cfg.MODEL.HEATMAP_SOFTMAX).cpu().numpy().squeeze()
            #kps_pred_np = kornia.spatial_soft_argmax2d(output, normalized_coordinates=False) # 1 x 21 x 2
        #kps_pred_np =  kps_pred_np[0] * np.array([256 / cfg.MODEL.HEATMAP_SIZE[0], 256 / cfg.MODEL.HEATMAP_SIZE[0]])
        

        #kps_pred_np[:,0] += 25
        # Heatmap-visualization branch; the original-image overlay in the else branch below is disabled.
        if True:
            all_flag = False
            kps_pred_np *=  cfg.MODEL.IMAGE_SIZE[0] / cfg.MODEL.HEATMAP_SIZE[0]
            heatmap_all = np.zeros(tuple(output.shape[2:]))
            heatmap_lst = []

            fig = plt.figure()           
            ax1 = fig.add_subplot(1,2,1)
            ax1.imshow(resized_image)

            for kp in range(0,21): 
                heatmap = output[0][kp].cpu().numpy()
                heatmap_lst.append(heatmap)
                heatmap_all += heatmap

            if not all_flag:
                # Note: kp still holds the last loop value (20), so only the
                # final keypoint is scattered and used for the title below.
                ax2 = fig.add_subplot(1,2,2)
                heatmap_cat = np.vstack((np.hstack(heatmap_lst[0:7]), np.hstack(heatmap_lst[7:14]), np.hstack(heatmap_lst[14:21])))
                print(heatmap_cat.shape)
                #hm = 255 * output[0][kp] / hms[0][kp].sum()
                ax1.scatter(kps_pred_np[kp][0], kps_pred_np[kp][1], linewidths=10)
                ax2.imshow(heatmap_cat)

                plt.title(kps_pred_np[kp].tolist())
                plt.show()

            if all_flag:
                
                ax2 = fig.add_subplot(1,2,2)
                ax1.imshow(resized_image)
                ax2.imshow(heatmap_all / heatmap_all.max())
                plt.show()
        else:
            kps_pred_np[:,0] *= orig_img.shape[1] / cfg.MODEL.HEATMAP_SIZE[0]
            kps_pred_np[:,1] *= orig_img.shape[0] / cfg.MODEL.HEATMAP_SIZE[0]
            kps_pred_np[:,0] += 0
            fig = plt.figure()
            fig.set_tight_layout(True)
            plt.imshow(cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB))

            plt.plot([kps_pred_np[0][0], kps_pred_np[legend_dict['thumb palm']][0]], [kps_pred_np[0][1], kps_pred_np[legend_dict['thumb palm']][1]], c='r', marker='.')
            plt.plot(kps_pred_np[1:5,0], kps_pred_np[1:5,1], c='r', marker='.', label='Thumb')
            plt.plot([kps_pred_np[0][0], kps_pred_np[legend_dict['index palm']][0]], [kps_pred_np[0][1], kps_pred_np[legend_dict['index palm']][1]], c='g', marker='.')
            plt.plot(kps_pred_np[5:9,0], kps_pred_np[5:9,1], c='g', marker='.', label='Index')
            plt.plot([kps_pred_np[0][0], kps_pred_np[legend_dict['middle palm']][0]], [kps_pred_np[0][1], kps_pred_np[legend_dict['middle palm']][1]], c='b', marker='.')
            plt.plot(kps_pred_np[9:13,0], kps_pred_np[9:13,1], c='b', marker='.', label='Middle')
            plt.plot([kps_pred_np[0][0], kps_pred_np[legend_dict['ring palm']][0]], [kps_pred_np[0][1], kps_pred_np[legend_dict['ring palm']][1]], c='m', marker='.')
            plt.plot(kps_pred_np[13:17,0], kps_pred_np[13:17,1], c='m', marker='.', label='Ring')
            plt.plot([kps_pred_np[0][0], kps_pred_np[legend_dict['pinky palm']][0]], [kps_pred_np[0][1], kps_pred_np[legend_dict['pinky palm']][1]], c='y', marker='.')
            plt.plot(kps_pred_np[17:21,0], kps_pred_np[17:21,1], c='y', marker='.', label='Pinky')
            plt.title('Prediction')
            if img_path:
                plt.title(img_path)
            plt.axis('off')
            plt.legend(bbox_to_anchor=(1.04, 1), loc="upper right", ncol=1, mode="expand", borderaxespad=0.)


        fig.canvas.draw()
        # Get the RGBA buffer from the figure
        buf = fig.canvas.buffer_rgba()

        if show:
            plt.show()
            print(kps_pred_np)


        return np.asarray(buf), kps_pred_np


    if os.path.isdir(args.image_path):
        # a bunch of images
        imgpath_lst = os.listdir(args.image_path)
        for p in imgpath_lst:
            if p.endswith('mp4'):continue
            print(os.path.join(args.image_path, p))
            I = cv2.imread(os.path.join(args.image_path, p), cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
            predict_one_img(I, True, os.path.join(args.image_path, p))
    else:
        try:
            # an image
            I = cv2.imread(args.image_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
            predict_one_img(I, True, args.image_path)
        except Exception as e:
            print(e)
            # video
            v = cv2.VideoCapture(args.image_path)
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            videoWriter = cv2.VideoWriter('pred_results.mp4', fourcc, 10, (640, 480))
            count = 0
            pose2d_pred_lst = []
            while v.isOpened():
                ret, frame = v.read()
                count += 1
                if count < 0: continue
                if not ret:
                    break
                result, pose2d_pred = predict_one_img(frame, show=False)
                pose2d_pred_lst.append(pose2d_pred)
                print(result.shape)
                videoWriter.write(result[:,:,0:-1])
                if count == 130:
                    break
            np.savetxt('./pose2d_pred.txt', np.concatenate(pose2d_pred_lst, axis=0))
            v.release()
            videoWriter.release()
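
# A small hedged helper sketch: the buffer returned by predict_one_img() above is
# RGBA in RGB channel order and may not match the VideoWriter's frame size, so a
# conversion along these lines is typically needed before writing. The helper name
# and default size are illustrative only.
import cv2
import numpy as np

def rgba_buffer_to_bgr_frame(buf, frame_size=(640, 480)):
    rgb = np.asarray(buf)[:, :, :3]             # drop the alpha channel
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)  # Matplotlib gives RGB, OpenCV expects BGR
    return cv2.resize(bgr, frame_size)          # match the VideoWriter dimensions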

# Example 6
def main():
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # CUDA settings
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=False)

    rand_input = torch.randn(1, 3, cfg.DATASET.INPUT_SIZE,
                             cfg.DATASET.INPUT_SIZE)
    logger.info(get_model_summary(model, rand_input, verbose=cfg.VERBOSE))

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    ])

# Example 7
def main_worker(gpus, ngpus_per_node, args, final_output_dir, tb_log_dir):
    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    #os.environ['CUDA_VISIBLE_DEVICES']=gpus

    # if len(gpus) == 1:
    #     gpus = int(gpus)

    update_config(cfg, args)

    #test(cfg, args)

    # logger setting
    logger, _ = setup_logger(final_output_dir, args.rank, 'train')

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # model initialization
    model = {
        "ransac": RANSACTriangulationNet,
        "alg": AlgebraicTriangulationNet,
        "vol": VolumetricTriangulationNet,
        "vol_CPM": VolumetricTriangulationNet_CPM,
        "FTL": FTLMultiviewNet
    }[cfg.MODEL.NAME](cfg)

    discriminator = Discriminator(cfg)

    # load pretrained model before DDP initialization
    if cfg.AUTO_RESUME:
        checkpoint_file = os.path.join(final_output_dir, 'model_best.pth.tar')
        if os.path.exists(checkpoint_file):
            checkpoint = torch.load(checkpoint_file,
                                    map_location=torch.device('cpu'))
            state_dict = checkpoint['state_dict']
            D_state_dict = checkpoint['D_state_dict']

            for key in list(state_dict.keys()):
                new_key = key.replace("module.", "")
                state_dict[new_key] = state_dict.pop(key)
            for key in list(D_state_dict.keys()):
                new_key = key.replace("module.", "")
                D_state_dict[new_key] = D_state_dict.pop(key)

            model.load_state_dict(state_dict)
            discriminator.load_state_dict(D_state_dict)
            logger.info("=> Loading checkpoint '{}' (epoch {})".format(
                checkpoint_file, checkpoint['epoch']))
        else:
            print('[Warning] Checkpoint file not found! Wrong path: {}'.format(
                checkpoint_file))

    elif cfg.MODEL.HRNET_PRETRAINED:
        logger.info("=> loading a pretrained model '{}'".format(
            cfg.MODEL.PRETRAINED))
        checkpoint = torch.load(cfg.MODEL.HRNET_PRETRAINED)

        state_dict = checkpoint['state_dict']
        for key in list(state_dict.keys()):
            new_key = key.replace("module.", "")
            state_dict[new_key] = state_dict.pop(key)

        model.load_state_dict(state_dict)

    # initialize the optimizer
    # the optimizer must be initialized after model initialization
    if cfg.MODEL.TRIANGULATION_MODEL_NAME == "vol":
        optimizer = torch.optim.Adam([{
            'params': model.backbone.parameters(),
            'initial_lr': cfg.TRAIN.LR
        }, {
            'params':
            model.process_features.parameters(),
            'initial_lr':
            cfg.TRAIN.PROCESS_FEATURE_LR
            if hasattr(cfg.TRAIN, "PROCESS_FEATURE_LR") else cfg.TRAIN.LR
        }, {
            'params':
            model.volume_net.parameters(),
            'initial_lr':
            cfg.TRAIN.VOLUME_NET_LR
            if hasattr(cfg.TRAIN, "VOLUME_NET_LR") else cfg.TRAIN.LR
        }],
                                     lr=cfg.TRAIN.LR)
    else:
        optimizer = torch.optim.Adam(
            [{
                'params': filter(lambda p: p.requires_grad,
                                 model.parameters()),
                'initial_lr': cfg.TRAIN.LR
            }],
            lr=cfg.TRAIN.LR)

    D_optimizer = torch.optim.RMSprop([{
        'params':
        filter(lambda p: p.requires_grad, discriminator.parameters()),
        'initial_lr':
        cfg.TRAIN.LR
    }],
                                      lr=cfg.TRAIN.LR)

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(os.path.join(this_dir, '../lib/models', 'triangulation.py'),
                 final_output_dir)
    # copy configuration file
    config_dir = args.cfg
    shutil.copy2(os.path.join(args.cfg), final_output_dir)

    # calculate GFLOPS
    # dump_input = torch.rand(
    #     (1, 4, 3, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[0])
    # )

    # logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    # FP16 SETTING
    if cfg.FP16.ENABLED:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if cfg.FP16.STATIC_LOSS_SCALE != 1.0:
        if not cfg.FP16.ENABLED:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.MODEL.SYNC_BN and not cfg.DISTRIBUTED:
        print(
            'Warning: Sync BatchNorm is only supported in distributed training.'
        )

    if cfg.FP16.ENABLED:
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=cfg.FP16.STATIC_LOSS_SCALE,
            dynamic_loss_scale=cfg.FP16.DYNAMIC_LOSS_SCALE,
            verbose=False)

    # Distributed Computing
    master = True
    if cfg.DISTRIBUTED:  # note: this distributed branch is not currently supported
        args.local_rank += int(gpus[0])
        print('This process is using GPU', args.local_rank)
        device = args.local_rank
        master = device == int(gpus[0])
        dist.init_process_group(backend='nccl')
        if cfg.MODEL.SYNC_BN:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if gpus is not None:
            torch.cuda.set_device(device)
            model.cuda(device)
            discriminator.cuda(device)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            # workers = int(workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[device],
                output_device=device,
                find_unused_parameters=True)
            discriminator = torch.nn.parallel.DistributedDataParallel(
                discriminator,
                device_ids=[device],
                output_device=device,
                find_unused_parameters=True)
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    else:  # this is the branch actually used (DataParallel)
        gpu_ids = eval('[' + gpus + ']')
        device = gpu_ids[0]
        print('This process is using GPU', str(device))
        model = torch.nn.DataParallel(model, gpu_ids).cuda(device)
        discriminator = torch.nn.DataParallel(discriminator,
                                              gpu_ids).cuda(device)

    # Prepare loss functions
    criterion = {}
    if cfg.LOSS.WITH_HEATMAP_LOSS:
        criterion['heatmap_loss'] = HeatmapLoss().cuda(device)
    if cfg.LOSS.WITH_POSE2D_LOSS:
        criterion['pose2d_loss'] = JointsMSELoss().cuda(device)
    if cfg.LOSS.WITH_POSE3D_LOSS:
        criterion['pose3d_loss'] = Joints3DMSELoss().cuda(device)
    if cfg.LOSS.WITH_VOLUMETRIC_CE_LOSS:
        criterion['volumetric_ce_loss'] = VolumetricCELoss().cuda(device)
    if cfg.LOSS.WITH_BONE_LOSS:
        criterion['bone_loss'] = BoneLengthLoss().cuda(device)
    if cfg.LOSS.WITH_TIME_CONSISTENCY_LOSS:
        criterion['time_consistency_loss'] = Joints3DMSELoss().cuda(device)
    if cfg.LOSS.WITH_KCS_LOSS:
        criterion['KCS_loss'] = None
    if cfg.LOSS.WITH_JOINTANGLE_LOSS:
        criterion['jointangle_loss'] = JointAngleLoss().cuda(device)

    best_perf = 1e9
    best_model = False
    last_epoch = -1

    # load history
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        begin_epoch = checkpoint['epoch'] + 1
        best_perf = checkpoint['loss']
        optimizer.load_state_dict(checkpoint['optimizer'])
        D_optimizer.load_state_dict(checkpoint['D_optimizer'])

        if 'train_global_steps' in checkpoint.keys() and \
        'valid_global_steps' in checkpoint.keys():
            writer_dict['train_global_steps'] = checkpoint[
                'train_global_steps']
            writer_dict['valid_global_steps'] = checkpoint[
                'valid_global_steps']

    # Floating point 16 mode
    if cfg.FP16.ENABLED:
        logger.info("=> Using FP16 mode")
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer.optimizer,
            cfg.TRAIN.LR_STEP,
            cfg.TRAIN.LR_FACTOR,
            last_epoch=begin_epoch)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            cfg.TRAIN.LR_STEP,
            cfg.TRAIN.LR_FACTOR,
            last_epoch=begin_epoch)

    # Data loading code
    train_loader_dict = make_dataloader(cfg,
                                        is_train=True,
                                        distributed=cfg.DISTRIBUTED)
    valid_loader_dict = make_dataloader(cfg,
                                        is_train=False,
                                        distributed=cfg.DISTRIBUTED)

    for i, (dataset_name,
            train_loader) in enumerate(train_loader_dict.items()):
        logger.info(
            'Training Loader {}/{}:\n'.format(i + 1, len(train_loader_dict)) +
            str(train_loader.dataset))
    for i, (dataset_name,
            valid_loader) in enumerate(valid_loader_dict.items()):
        logger.info('Validation Loader {}/{}:\n'.format(
            i + 1, len(valid_loader_dict)) + str(valid_loader.dataset))

    #writer_dict['writer'].add_graph(model, (dump_input, ))
    """
    Start training
    """
    start_time = time.time()

    with torch.autograd.set_detect_anomaly(True):
        for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
            epoch_start_time = time.time()
            # shuffle datasets with the same random seed
            if cfg.DISTRIBUTED:
                for data_loader in train_loader_dict.values():
                    data_loader.sampler.set_epoch(epoch)
            # train for one epoch
            logger.info('Start training [{}/{}]'.format(
                epoch, cfg.TRAIN.END_EPOCH - 1))
            train(epoch,
                  cfg,
                  args,
                  master,
                  train_loader_dict, [model, discriminator],
                  criterion, [optimizer, D_optimizer],
                  final_output_dir,
                  tb_log_dir,
                  writer_dict,
                  logger,
                  device,
                  fp16=cfg.FP16.ENABLED)

            # In PyTorch 1.1.0 and later, you should call `lr_scheduler.step()` after `optimizer.step()`.
            lr_scheduler.step()

            # evaluate on validation set
            if not cfg.WITHOUT_EVAL:
                logger.info('Start evaluating [{}/{}]'.format(
                    epoch, cfg.TRAIN.END_EPOCH - 1))
                with torch.no_grad():
                    recorder = validate(cfg, args, master, valid_loader_dict,
                                        [model, discriminator], criterion,
                                        final_output_dir, tb_log_dir,
                                        writer_dict, logger, device)

                val_total_loss = recorder.avg_total_loss

                if val_total_loss < best_perf:
                    logger.info(
                        'This epoch yielded a better model with total loss {:.4f} < {:.4f}.'
                        .format(val_total_loss, best_perf))
                    best_perf = val_total_loss
                    best_model = True
                else:
                    best_model = False

            else:
                val_total_loss = 0
                best_model = True

            logger.info('=> saving checkpoint to {}'.format(final_output_dir))
            save_checkpoint(
                {
                    'epoch': epoch,
                    'model': cfg.EXP_NAME + '.' + cfg.MODEL.NAME,
                    'state_dict': model.state_dict(),
                    'D_state_dict': discriminator.state_dict(),
                    'loss': val_total_loss,
                    'optimizer': optimizer.state_dict(),
                    'D_optimizer': D_optimizer.state_dict(),
                    'train_global_steps': writer_dict['train_global_steps'],
                    'valid_global_steps': writer_dict['valid_global_steps']
                }, best_model, final_output_dir)

            print('\nEpoch {} spent {:.2f} hours\n'.format(
                epoch, (time.time() - epoch_start_time) / 3600))

            #if epoch == 3:break
    if master:
        final_model_state_file = os.path.join(
            final_output_dir, 'final_state{}.pth.tar'.format(gpus))
        logger.info(
            '=> saving final model state to {}'.format(final_model_state_file))
        torch.save(model.state_dict(), final_model_state_file)
        writer_dict['writer'].close()

        print(
            '\n[Training Accomplished] {} epochs spent {:.2f} hours\n'.format(
                cfg.TRAIN.END_EPOCH - begin_epoch + 1,
                (time.time() - start_time) / 3600))
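
# A hedged sketch of 2D keypoint decoding in the spirit of the get_final_preds()
# helper used by the evaluation scripts in this file: plain argmax by default, or
# a spatial soft-argmax when softmax-style decoding is enabled. Shapes assumed:
# heatmaps (B, K, H, W) in, (B, K, 2) [u, v] heatmap-pixel coordinates out.
# This is an illustration, not the actual implementation.
import torch

def decode_heatmaps(heatmaps, use_softmax=False):
    b, k, h, w = heatmaps.shape
    flat = heatmaps.view(b, k, -1)
    if use_softmax:
        prob = torch.softmax(flat, dim=-1).view(b, k, h, w)
        xs = torch.arange(w, dtype=prob.dtype, device=prob.device)
        ys = torch.arange(h, dtype=prob.dtype, device=prob.device)
        u = (prob.sum(dim=2) * xs).sum(dim=-1)  # expected column index
        v = (prob.sum(dim=3) * ys).sum(dim=-1)  # expected row index
    else:
        idx = flat.argmax(dim=-1)
        u = (idx % w).float()
        v = torch.div(idx, w, rounding_mode='floor').float()
    return torch.stack((u, v), dim=-1)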
def main():
    args = parse_args()
    update_config(cfg, args)
    cfg.defrost()
    cfg.freeze()

    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model_path = args.model_path
    is_vis = args.is_vis

    gpus = ','.join([str(i) for i in cfg.GPUS])
    gpu_ids = eval('[' + gpus + ']')

    if cfg.FP16.ENABLED:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if cfg.FP16.STATIC_LOSS_SCALE != 1.0:
        if not cfg.FP16.ENABLED:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    # model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
    #     cfg, is_train=True
    # )

    if 'pose_hrnet' in cfg.MODEL.NAME:
        model = {
            "pose_hrnet": pose_hrnet.get_pose_net,
            "pose_hrnet_softmax": pose_hrnet_softmax.get_pose_net
        }[cfg.MODEL.NAME](cfg, is_train=True)
    else:
        model = {
            "ransac": RANSACTriangulationNet,
            "alg": AlgebraicTriangulationNet,
            "vol": VolumetricTriangulationNet,
            "vol_CPM": VolumetricTriangulationNet_CPM,
            "FTL": FTLMultiviewNet
        }[cfg.MODEL.NAME](cfg, is_train=False)

    # load model state
    if model_path:
        print("Loading model:", model_path)
        ckpt = torch.load(model_path,
                          map_location='cpu' if args.gpu == -1 else 'cuda:0')
        if 'state_dict' not in ckpt.keys():
            state_dict = ckpt
        else:
            state_dict = ckpt['state_dict']
            print('Model epoch {}'.format(ckpt['epoch']))

        for key in list(state_dict.keys()):
            new_key = key.replace("module.", "")
            state_dict[new_key] = state_dict.pop(key)

        model.load_state_dict(state_dict, strict=True)

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.MODEL.SYNC_BN and not args.distributed:
        print(
            'Warning: Sync BatchNorm is only supported in distributed training.'
        )

    device = torch.device('cuda:' + str(args.gpu) if args.gpu != -1 else 'cpu')

    model.to(device)

    model.eval()

    # image transformer
    transform = build_transforms(cfg, is_train=False)

    inference_dataset = eval('dataset.' + cfg.DATASET.TEST_DATASET[0])(
        cfg, cfg.DATASET.TEST_SET, transform=transform)

    data_loader = torch.utils.data.DataLoader(inference_dataset,
                                              batch_size=1,
                                              shuffle=True,
                                              num_workers=0,
                                              pin_memory=False)

    print('\nValidation loader information:\n' + str(data_loader.dataset))

    with torch.no_grad():
        pose2d_mse_loss = JointsMSELoss().to(
            device) if args.gpu != -1 else JointsMSELoss()
        pose3d_mse_loss = Joints3DMSELoss().to(
            device) if args.gpu != -1 else Joints3DMSELoss()
        orig_width, orig_height = inference_dataset.orig_img_size
        heatmap_size = cfg.MODEL.HEATMAP_SIZE
        count = 4
        for i, ret in enumerate(data_loader):
            # orig_imgs: 1 x 4 x 480 x 640 x 3
            # imgs: 1 x 4 x 3 x H x W
            # pose2d_gt (bounded in 64 x 64): 1 x 4 x 21 x 2
            # pose3d_gt: 1 x 21 x 3
            # visibility: 1 x 4 x 21
            # extrinsic matrix: 1 x 4 x 3 x 4
            # intrinsic matrix: 1 x 3 x 3
            if not (i % 67 == 0): continue

            imgs = ret['imgs'].to(device)
            orig_imgs = ret['orig_imgs']
            pose2d_gt, pose3d_gt, visibility = ret['pose2d'], ret[
                'pose3d'], ret['visibility']
            extrinsic_matrices, intrinsic_matrices = ret[
                'extrinsic_matrices'], ret['intrinsic_matrix']
            # sometimes intrinsic_matrices has a shape of 3x3 or b x 3x3
            intrinsic_matrix = intrinsic_matrices[0] if len(
                intrinsic_matrices.shape) == 3 else intrinsic_matrices

            start_time = time.time()
            if 'pose_hrnet' in cfg.MODEL.NAME:
                pose3d_gt = pose3d_gt.to(device)

                heatmaps, _ = model(imgs[0])  # N_views x 21 x 64 x 64
                pose2d_pred = get_final_preds(heatmaps,
                                              cfg)  # N_views x 21 x 2
                proj_matrices = (intrinsic_matrix @ extrinsic_matrices).to(
                    device)  # b x v x 3 x 4

                # rescale to the original image before DLT
                pose2d_pred[:, :, 0:1] *= orig_width / heatmap_size[0]
                pose2d_pred[:, :, 1:2] *= orig_height / heatmap_size[0]
                # 3D world coordinate 1 x 21 x 3
                pose3d_pred = DLT_pytorch(pose2d_pred,
                                          proj_matrices.squeeze()).unsqueeze(0)

            elif 'alg' == cfg.MODEL.NAME or 'ransac' == cfg.MODEL.NAME:
                # the predicted 2D poses have been rescaled inside the triangulation model
                # pose2d_pred: 1 x N_views x 21 x 2
                # pose3d_pred: 1 x 21 x 3
                proj_matrices = (intrinsic_matrix @ extrinsic_matrices
                                 )  # b x v x 3 x 4

                pose3d_pred,\
                pose2d_pred,\
                heatmaps,\
                confidences_pred = model(imgs, proj_matrices.to(device))

            elif "vol" in cfg.MODEL.NAME:
                intrinsic_matrix = update_after_resize(
                    intrinsic_matrix, (orig_height, orig_width),
                    tuple(heatmap_size))
                proj_matrices = (intrinsic_matrix @ extrinsic_matrices).to(
                    device)  # b x v x 3 x 4

                # pose3d_pred (torch.tensor) b x 21 x 3
                # pose2d_pred (torch.tensor) b x v x 21 x 2 NOTE: the estimated 2D poses are located in the heatmap size 64(W) x 64(H)
                # heatmaps_pred (torch.tensor) b x v x 21 x 64 x 64
                # volumes_pred (torch.tensor)
                # confidences_pred (torch.tensor)
                # cuboids_pred (list)
                # coord_volumes_pred (torch.tensor)
                # base_points_pred (torch.tensor) b x v x 1 x 2
                if cfg.MODEL.BACKBONE_NAME == 'CPM_volumetric':
                    centermaps = ret['centermaps'].to(device)

                    pose3d_pred,\
                    pose2d_pred,\
                    heatmaps_pred,\
                    volumes_pred,\
                    confidences_pred,\
                    coord_volumes_pred,\
                    base_points_pred\
                        = model(imgs, centermaps, proj_matrices)
                else:
                    pose3d_pred,\
                    pose2d_pred,\
                    heatmaps,\
                    volumes_pred,\
                    confidences_pred,\
                    coord_volumes_pred,\
                    base_points_pred\
                        = model(imgs, proj_matrices)

                pose2d_pred[:, :, :, 0:1] *= orig_width / heatmap_size[0]
                pose2d_pred[:, :, :, 1:2] *= orig_height / heatmap_size[0]

            elif 'FTL' == cfg.MODEL.NAME:
                # pose2d_pred: 1 x 4 x 21 x 2
                # pose3d_pred: 1 x 21 x 3
                heatmaps, pose2d_pred, pose3d_pred = model(
                    imgs.to(device), extrinsic_matrices.to(device),
                    intrinsic_matrix.to(device))

                # print(pose2d_pred)  # debug output left disabled
                pose2d_pred = torch.cat((pose2d_pred[:, :, :, 0:1] * 640 / 64,
                                         pose2d_pred[:, :, :, 1:2] * 480 / 64),
                                        dim=-1)

            # N_views x 21 x 2
            end_time = time.time()
            print('3D pose inference time {:.1f} ms'.format(
                1000 * (end_time - start_time)))
            pose3d_EPE = pose3d_mse_loss(pose3d_pred[:, 1:],
                                         pose3d_gt[:, 1:].to(device)).item()
            print('Pose3d MSE: {:.4f}\n'.format(pose3d_EPE))

            # if pose3d_EPE > 35:
            #     input()
            #     continue
            # 2D errors
            pose2d_gt[:, :, :, 0] *= orig_width / heatmap_size[0]
            pose2d_gt[:, :, :, 1] *= orig_height / heatmap_size[1]

            # for k in range(21):
            #     print(pose2d_gt[0,k].tolist(), pose2d_pred[0,k].tolist())
            # input()

            visualize(args=args,
                      imgs=np.squeeze(orig_imgs[0].numpy()),
                      pose2d_gt=np.squeeze(pose2d_gt.cpu().numpy()),
                      pose2d_pred=np.squeeze(pose2d_pred.cpu().numpy()),
                      pose3d_gt=np.squeeze(pose3d_gt.cpu().numpy()),
                      pose3d_pred=np.squeeze(pose3d_pred.cpu().numpy()))
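
# --- Illustrative sketch (not the repository's DLT_pytorch implementation) ---
# Linear (DLT) triangulation of a single joint from multiple calibrated views:
# each 3x4 projection matrix P and 2D observation (x, y) contributes two rows
# to a homogeneous system A X = 0, and the 3D point is the right singular
# vector with the smallest singular value, de-homogenized. Names and shapes
# below are assumptions, not the exact interface used above.
import torch


def triangulate_dlt(points2d, proj_matrices):
    """points2d: V x 2 tensor, proj_matrices: V x 3 x 4 tensor -> (3,) point."""
    rows = []
    for v in range(points2d.shape[0]):
        P = proj_matrices[v]
        x, y = points2d[v, 0], points2d[v, 1]
        rows.append(x * P[2] - P[0])   # x * (third row) - (first row)
        rows.append(y * P[2] - P[1])   # y * (third row) - (second row)
    A = torch.stack(rows, dim=0)       # 2V x 4
    _, _, Vh = torch.linalg.svd(A)
    X = Vh[-1]                         # null-space direction of A
    return X[:3] / X[3]                # de-homogenize
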
def main():
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid'
    )

    logger.info(pprint.pformat(args))

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=False
    )

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
    else:
        model_state_file = os.path.join(
            final_output_dir, 'model_best.pth.tar'
        )
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    data_loader, test_dataset = make_test_dataloader(cfg)

    if cfg.MODEL.NAME == 'pose_hourglass':
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )
    else:
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
        )

    parser = HeatmapParser(cfg)

    vid_file = 0 # Or video file path
    print("Opening Camera " + str(vid_file))
    cap = cv2.VideoCapture(vid_file)

    while True:
        ret, image = cap.read()

        a = datetime.datetime.now()

        # size at scale 1.0
        base_size, center, scale = get_multi_scale_size(
            image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR)
        )

        with torch.no_grad():
            final_heatmaps = None
            tags_list = []
            for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR, reverse=True)):
                input_size = cfg.DATASET.INPUT_SIZE
                image_resized, center, scale = resize_align_multi_scale(
                    image, input_size, s, min(cfg.TEST.SCALE_FACTOR)
                )
                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                outputs, heatmaps, tags = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                    cfg.TEST.PROJECT2IMAGE, base_size
                )

                final_heatmaps, tags_list = aggregate_results(
                    cfg, s, final_heatmaps, tags_list, heatmaps, tags
                )

            final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
            tags = torch.cat(tags_list, dim=4)
            grouped, scores = parser.parse(
                final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE
            )

            final_results = get_final_preds(
                grouped, center, scale,
                [final_heatmaps.size(3), final_heatmaps.size(2)]
            )

        b = datetime.datetime.now()
        inf_time = (b - a).total_seconds()*1000
        print("Inf time {} ms".format(inf_time))

        # Display the resulting frame
        for person in final_results:
            color = np.random.randint(0, 255, size=3)
            color = [int(i) for i in color]
            add_joints(image, person, color, test_dataset.name, cfg.TEST.DETECTION_THRESHOLD)

        image = cv2.putText(image, "{:.2f} ms / frame".format(inf_time), (40, 40),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
        cv2.imshow('frame', image)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
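
# --- Illustrative sketch (an assumption about what aggregate_results does) ---
# The multi-scale loop above accumulates heatmaps produced at every test scale:
# conceptually each scale's heatmaps are resized to the common base resolution
# and summed (the caller divides by the number of scales afterwards), while the
# tag maps are collected in tags_list and concatenated for the grouping step.
# accumulate_heatmaps below is a hypothetical helper, not the repository code.
import torch
import torch.nn.functional as F


def accumulate_heatmaps(final_heatmaps, heatmaps, base_size):
    """heatmaps: 1 x K x h x w tensor; base_size: (W, H) of the base resolution."""
    resized = F.interpolate(heatmaps, size=(base_size[1], base_size[0]),
                            mode='bilinear', align_corners=False)
    return resized if final_heatmaps is None else final_heatmaps + resized
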
def main():
    args = parse_args()

    update_config(cfg, args)
    cfg.defrost()
    cfg.freeze()

    if args.is_vis:
        result_dir = prefix + cfg.EXP_NAME
        mse2d_lst = np.loadtxt(os.path.join(result_dir,
                                            'mse2d_each_joint.txt'))
        mse3d_lst = np.loadtxt(os.path.join(result_dir,
                                            'mse3d_each_joint.txt'))
        PCK2d_lst = np.loadtxt(os.path.join(result_dir, 'PCK2d.txt'))
        PCK3d_lst = np.loadtxt(os.path.join(result_dir, 'PCK3d.txt'))

        plot_performance(PCK2d_lst[1, :], PCK2d_lst[0, :], PCK3d_lst[1, :],
                         PCK3d_lst[0, :], mse2d_lst, mse3d_lst)
        exit()

    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model_path = args.model_path
    is_vis = args.is_vis

    gpus = ','.join([str(i) for i in cfg.GPUS])
    gpu_ids = eval('[' + gpus + ']')

    if cfg.FP16.ENABLED:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if cfg.FP16.STATIC_LOSS_SCALE != 1.0:
        if not cfg.FP16.ENABLED:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    if 'pose_hrnet' in cfg.MODEL.NAME:
        model = {
            "pose_hrnet": pose_hrnet.get_pose_net,
            "pose_hrnet_softmax": pose_hrnet_softmax.get_pose_net
        }[cfg.MODEL.NAME](cfg, is_train=True)
    else:
        model = {
            "ransac": RANSACTriangulationNet,
            "alg": AlgebraicTriangulationNet,
            "vol": VolumetricTriangulationNet,
            "vol_CPM": VolumetricTriangulationNet_CPM,
            "FTL": FTLMultiviewNet
        }[cfg.MODEL.NAME](cfg, is_train=False)

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.MODEL.SYNC_BN and not args.distributed:
        print(
            'Warning: Sync BatchNorm is only supported in distributed training.'
        )

    # load model state
    if model_path:
        print("Loading model:", model_path)
        ckpt = torch.load(model_path,
                          map_location='cpu' if args.gpu == -1 else 'cuda:0')
        if 'state_dict' not in ckpt.keys():
            state_dict = ckpt
        else:
            state_dict = ckpt['state_dict']
            print('Model epoch {}'.format(ckpt['epoch']))

        for key in list(state_dict.keys()):
            new_key = key.replace("module.", "")
            state_dict[new_key] = state_dict.pop(key)

        model.load_state_dict(state_dict, strict=False)

    device = torch.device('cuda:' + str(args.gpu) if args.gpu != -1 else 'cpu')

    model.to(device)

    model.eval()

    # image transformer
    transform = build_transforms(cfg, is_train=False)

    inference_dataset = eval('dataset.' + cfg.DATASET.DATASET[0])(
        cfg, cfg.DATASET.TEST_SET, transform=transform)
    inference_dataset.n_views = eval(args.views)
    batch_size = args.batch_size
    if platform.system() == 'Linux':  # for linux
        data_loader = torch.utils.data.DataLoader(inference_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=8,
                                                  pin_memory=False)
    else:  # for windows
        batch_size = 1
        data_loader = torch.utils.data.DataLoader(inference_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  num_workers=0,
                                                  pin_memory=False)

    print('\nEvaluation loader information:\n' + str(data_loader.dataset))
    print('Evaluation batch size: {}\n'.format(batch_size))

    th2d_lst = np.array([i for i in range(1, 50)])
    PCK2d_lst = np.zeros((len(th2d_lst), ))
    mse2d_lst = np.zeros((21, ))
    th3d_lst = np.array([i for i in range(1, 51)])
    PCK3d_lst = np.zeros((len(th3d_lst), ))
    mse3d_lst = np.zeros((21, ))
    visibility_lst = np.zeros((21, ))
    with torch.no_grad():
        start_time = time.time()
        pose2d_mse_loss = JointsMSELoss().cuda(
            args.gpu) if args.gpu != -1 else JointsMSELoss()
        pose3d_mse_loss = Joints3DMSELoss().cuda(
            args.gpu) if args.gpu != -1 else Joints3DMSELoss()

        infer_time = [0, 0]
        start_time = time.time()
        n_valid = 0
        model.orig_img_size = inference_dataset.orig_img_size
        orig_width, orig_height = model.orig_img_size
        heatmap_size = cfg.MODEL.HEATMAP_SIZE

        for i, ret in enumerate(data_loader):
            # ori_imgs: b x 4 x 480 x 640 x 3
            # imgs: b x 4 x 3 x H x W
            # pose2d_gt: b x 4 x 21 x 2 (have not been transformed)
            # pose3d_gt: b x 21 x 3
            # visibility: b x 4 x 21
            # extrinsic matrix: b x 4 x 3 x 4
            # intrinsic matrix: b x 3 x 3
            # if i < count: continue
            imgs = ret['imgs'].to(device)
            orig_imgs = ret['orig_imgs']
            pose2d_gt, pose3d_gt, visibility = ret['pose2d'], ret[
                'pose3d'], ret['visibility']
            extrinsic_matrices, intrinsic_matrices = ret[
                'extrinsic_matrices'], ret['intrinsic_matrix']
            # sometimes intrinsic_matrices has shape 3x3 or b x 3x3
            intrinsic_matrix = intrinsic_matrices[0] if len(
                intrinsic_matrices.shape) == 3 else intrinsic_matrices

            batch_size = orig_imgs.shape[0]
            n_joints = pose2d_gt.shape[2]
            pose2d_gt = pose2d_gt.view(
                -1, *pose2d_gt.shape[2:]).numpy()  # b*v x 21 x 2
            pose3d_gt = pose3d_gt.numpy()  # b x 21 x 3
            visibility = visibility.view(
                -1, visibility.shape[2]).numpy()  # b*v x 21

            if 'pose_hrnet' in cfg.MODEL.NAME:
                s1 = time.time()
                heatmaps, _ = model(imgs.view(
                    -1, *imgs.shape[2:]))  # b*v x 21 x 64 x 64
                pose2d_pred = get_final_preds(heatmaps, cfg).view(
                    batch_size, -1, n_joints, 2
                )  # b x v x 21 x 2 NOTE: the estimated 2D poses are located in the heatmap size 64(W) x 64(H)
                proj_matrices = (intrinsic_matrix @ extrinsic_matrices).to(
                    device)  # b x v x 3 x 4
                # rescale to the original image before DLT
                pose2d_pred[:, :, :, 0:1] *= orig_width / heatmap_size[0]
                pose2d_pred[:, :, :, 1:2] *= orig_height / heatmap_size[0]

                # 3D world coordinate 1 x 21 x 3
                pose3d_pred = torch.cat([
                    DLT_sii_pytorch(pose2d_pred[:, :, k],
                                    proj_matrices).unsqueeze(1)
                    for k in range(n_joints)
                ],
                                        dim=1)  # b x 21 x 3

                if i > 20:
                    infer_time[0] += 1
                    infer_time[1] += time.time() - s1
                    #print('FPS {:.1f}'.format(infer_time[0]/infer_time[1]))

            elif 'alg' == cfg.MODEL.NAME or 'ransac' == cfg.MODEL.NAME:
                s1 = time.time()
                # pose2d_pred: b x N_views x 21 x 2
                # NOTE: the estimated 2D poses are located in the original image of size 640(W) x 480(H)
                # pose3d_pred: b x 21 x 3 [world coord]
                proj_matrices = (intrinsic_matrix @ extrinsic_matrices).to(
                    device)  # b x v x 3 x 4
                pose3d_pred,\
                pose2d_pred,\
                heatmaps,\
                confidences_pred = model(imgs.to(device), proj_matrices.to(device))
                if i > 20:
                    infer_time[0] += 1
                    infer_time[1] += time.time() - s1

            elif "vol" in cfg.MODEL.NAME:
                intrinsic_matrix = update_after_resize(
                    intrinsic_matrix, (orig_height, orig_width),
                    tuple(heatmap_size))
                proj_matrices = (intrinsic_matrix @ extrinsic_matrices).to(
                    device)  # b x v x 3 x 4
                s1 = time.time()

                # pose3d_pred (torch.tensor) b x 21 x 3
                # pose2d_pred (torch.tensor) b x v x 21 x 2 NOTE: the estimated 2D poses are located in the heatmap size 64(W) x 64(H)
                # heatmaps_pred (torch.tensor) b x v x 21 x 64 x 64
                # volumes_pred (torch.tensor)
                # confidences_pred (torch.tensor)
                # cuboids_pred (list)
                # coord_volumes_pred (torch.tensor)
                # base_points_pred (torch.tensor) b x v x 1 x 2
                if cfg.MODEL.BACKBONE_NAME == 'CPM_volumetric':
                    centermaps = ret['centermaps'].to(device)
                    heatmaps_gt = ret['heatmaps']

                    pose3d_pred,\
                    pose2d_pred,\
                    heatmaps_pred,\
                    volumes_pred,\
                    confidences_pred,\
                    coord_volumes_pred,\
                    base_points_pred\
                        = model(imgs, centermaps, proj_matrices)
                else:
                    pose3d_pred,\
                    pose2d_pred,\
                    heatmaps,\
                    volumes_pred,\
                    confidences_pred,\
                    coord_volumes_pred,\
                    base_points_pred\
                        = model(imgs, proj_matrices)

                if i > 20:
                    infer_time[0] += 1
                    infer_time[1] += time.time() - s1

                pose2d_pred[:, :, :, 0:1] *= orig_width / heatmap_size[0]
                pose2d_pred[:, :, :, 1:2] *= orig_height / heatmap_size[1]

            # 2D errors
            pose2d_gt[:, :, 0] *= orig_width / heatmap_size[0]
            pose2d_gt[:, :, 1] *= orig_height / heatmap_size[1]

            pose2d_pred = pose2d_pred.view(-1, n_joints,
                                           2).cpu().numpy()  # b*v x 21 x 2
            # debug: compare GT and predicted 2D keypoints for the first sample
            # for k in range(21):
            #     print(pose2d_gt[0, k].tolist(), pose2d_pred[0, k].tolist())
            # input()
            mse_each_joint = np.linalg.norm(pose2d_pred - pose2d_gt,
                                            axis=2) * visibility  # b*v x 21
            mse2d_lst += mse_each_joint.sum(axis=0)
            visibility_lst += visibility.sum(axis=0)

            for th_idx in range(len(th2d_lst)):
                PCK2d_lst[th_idx] += np.sum(
                    (mse_each_joint < th2d_lst[th_idx]) * visibility)

            # 3D errors
            # debug: compare GT and predicted 3D joints for the first sample
            # for k in range(21):
            #     print(pose3d_gt[0, k].tolist(), pose3d_pred[0, k].tolist())
            # input()
            visibility = visibility.reshape(
                (batch_size, -1, n_joints))  # b x v x 21
            for b in range(batch_size):
                # print(np.sum(visibility[b]), visibility[b].size)
                if np.sum(visibility[b]) >= visibility[b].size * 0.65:
                    n_valid += 1
                    mse_each_joint = np.linalg.norm(
                        pose3d_pred[b].cpu().numpy() - pose3d_gt[b],
                        axis=1)  # 21
                    mse3d_lst += mse_each_joint

                    for th_idx in range(len(th3d_lst)):
                        PCK3d_lst[th_idx] += np.sum(
                            mse_each_joint < th3d_lst[th_idx])

            if i % (len(data_loader) // 5) == 0:
                print("[Evaluation]{}% finished.".format(
                    20 * i // (len(data_loader) // 5)))
            #if i == 10:break
        print('Evaluation spent {:.2f} s\tFPS: {:.1f}'.format(
            time.time() - start_time, infer_time[0] / infer_time[1]))

        mse2d_lst /= visibility_lst
        PCK2d_lst /= visibility_lst.sum()
        mse3d_lst /= n_valid
        PCK3d_lst /= (n_valid * 21)
        plot_performance(PCK2d_lst, th2d_lst, PCK3d_lst, th3d_lst, mse2d_lst,
                         mse3d_lst)

        if not os.path.exists(result):
            os.mkdir(result)
        result_dir = prefix + cfg.EXP_NAME
        if not os.path.exists(result_dir):
            os.mkdir(result_dir)

        np.savetxt(os.path.join(result_dir, 'mse2d_each_joint.txt'),
                   mse2d_lst,
                   fmt='%.4f')
        np.savetxt(os.path.join(result_dir, 'mse3d_each_joint.txt'),
                   mse3d_lst,
                   fmt='%.4f')
        np.savetxt(os.path.join(result_dir, 'PCK2d.txt'),
                   np.stack((th2d_lst, PCK2d_lst)))
        np.savetxt(os.path.join(result_dir, 'PCK3d.txt'),
                   np.stack((th3d_lst, PCK3d_lst)))
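
# --- Illustrative sketch: reading the saved metrics back for a quick summary ---
# PCK2d.txt / PCK3d.txt are written above as a 2 x T array (row 0: thresholds,
# row 1: PCK at each threshold), which is also how the --is_vis branch indexes
# them. A trapezoidal AUC over the threshold range condenses the curve into one
# number; the default path below is a placeholder, not a path from the repo.
import numpy as np


def summarize_pck(path='PCK2d.txt'):
    pck = np.loadtxt(path)                 # 2 x T: thresholds, PCK values
    thresholds, values = pck[0], pck[1]
    auc = np.trapz(values, thresholds) / (thresholds[-1] - thresholds[0])
    print('PCK AUC over [{:.0f}, {:.0f}]: {:.4f}'.format(
        thresholds[0], thresholds[-1], auc))
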
Example #11
def main():
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=False
    )

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    print("Initilaized.")

    if cfg.TEST.MODEL_FILE:
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
    else:
        raise Exception("No weight file. Would you like to test with your hammer?")

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    # data_loader, test_dataset = make_test_dataloader(cfg)

    if cfg.MODEL.NAME == 'pose_hourglass':
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )
    else:
        # this branch is used by default
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
        )

    parser = HeatmapParser(cfg)

    print("Load model successfully.")

    ENABLE_CAMERA = 1
    ENABLE_VIDEO = 1

    VIDEO_ROTATE = 0
    
    if ENABLE_CAMERA:
        # open the camera stream
        cap = cv2.VideoCapture(-1)
        ret, image = cap.read()
        x, y = image.shape[0:2]
        print((x, y))
        # create the output video file
        fourcc = cv2.VideoWriter_fourcc(*'I420')
        # fourcc = cv2.VideoWriter_fourcc(*'X264')
        out = cv2.VideoWriter('./result.avi', fourcc, 24, (y, x), True)
        while ret:
            ret, image = cap.read()
            if not ret:
                break
            # live video: disable multi-scale search automatically
            base_size, center, scale = get_multi_scale_size(
                image, cfg.DATASET.INPUT_SIZE, 1.0, 1.0
            )
            with torch.no_grad():
                final_heatmaps = None
                tags_list = []
                input_size = cfg.DATASET.INPUT_SIZE
                image_resized, center, scale = resize_align_multi_scale(
                    image, input_size, 1.0, 1.0
                )
                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                outputs, heatmaps, tags = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                    cfg.TEST.PROJECT2IMAGE, base_size
                )

                final_heatmaps, tags_list = aggregate_results(
                    cfg, 1.0, final_heatmaps, tags_list, heatmaps, tags
                )

                final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
                tags = torch.cat(tags_list, dim=4)
                grouped, scores = parser.parse(
                    final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE
                )

                final_results = get_final_preds(
                    grouped, center, scale,
                    [final_heatmaps.size(3), final_heatmaps.size(2)]
                )

            detection = save_demo_image(image, final_results, mode=1)

            detection = cv2.cvtColor(detection, cv2.COLOR_BGR2RGB)
            cv2.imshow("Pose Estimation", detection)
            out.write(detection)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        cap.release()
        out.release()
        os.system("ffmpeg -i result.avi -c:v libx265 camera.mp4")
        cv2.destroyAllWindows()
    elif ENABLE_VIDEO:
        # open the video stream
        video_name = "./videos/test04.mp4"
        cap = cv2.VideoCapture(video_name)
        # create the output video file
        fourcc = cv2.VideoWriter_fourcc(*'I420')
        out = cv2.VideoWriter('./result.avi', fourcc, 24, (704, 576), True)
        while cap.isOpened():
            ret, image = cap.read()
            if not ret:
                break
            if VIDEO_ROTATE:  # only for the medicine-ball-throw videos
                image = cv2.resize(image, (960, 540)).transpose((1, 0, 2))
            # live video: disable multi-scale search automatically
            base_size, center, scale = get_multi_scale_size(
                image, cfg.DATASET.INPUT_SIZE, 1.0, 1.0
            )
            with torch.no_grad():
                final_heatmaps = None
                tags_list = []
                input_size = cfg.DATASET.INPUT_SIZE
                image_resized, center, scale = resize_align_multi_scale(
                    image, input_size, 1.0, 1.0
                )
                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                outputs, heatmaps, tags = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                    cfg.TEST.PROJECT2IMAGE, base_size
                )

                final_heatmaps, tags_list = aggregate_results(
                    cfg, 1.0, final_heatmaps, tags_list, heatmaps, tags
                )

                final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
                tags = torch.cat(tags_list, dim=4)
                grouped, scores = parser.parse(
                    final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE
                )

                final_results = get_final_preds(
                    grouped, center, scale,
                    [final_heatmaps.size(3), final_heatmaps.size(2)]
                )

            detection = save_demo_image(image, final_results, mode=1)

            detection = cv2.cvtColor(detection, cv2.COLOR_BGR2RGB)
            cv2.imshow("Pose Estimation", detection)
            out.write(detection)
            # print("frame")
            cv2.waitKey(1)

        cap.release()
        out.release()
        os.system("ffmpeg -i result.avi -c:v libx265 det04.mp4")
        cv2.destroyAllWindows()
    else:
        img_name = "./test.jpg"
        images = cv2.imread(img_name)
        image = images
        # size at scale 1.0
        base_size, center, scale = get_multi_scale_size(
            image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR)
        )

        with torch.no_grad():
            final_heatmaps = None
            tags_list = []
            print(cfg.TEST.SCALE_FACTOR)
            for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR, reverse=True)):
                input_size = cfg.DATASET.INPUT_SIZE
                image_resized, center, scale = resize_align_multi_scale(
                    image, input_size, s, min(cfg.TEST.SCALE_FACTOR)
                )
                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                outputs, heatmaps, tags = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                    cfg.TEST.PROJECT2IMAGE, base_size
                )

                final_heatmaps, tags_list = aggregate_results(
                    cfg, s, final_heatmaps, tags_list, heatmaps, tags
                )

            final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
            tags = torch.cat(tags_list, dim=4)
            grouped, scores = parser.parse(
                final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE
            )

            final_results = get_final_preds(
                grouped, center, scale,
                [final_heatmaps.size(3), final_heatmaps.size(2)]
            )

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        save_demo_image(image, final_results, file_name="./result.jpg")
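
# --- Illustrative sketch: the camera, video and image branches above repeat the
# same single-scale forward pass; a hedged refactoring (the function name and
# argument list are assumptions) would keep the three input sources but share
# the inference code. It relies on the same helpers already imported by this
# script (get_multi_scale_size, resize_align_multi_scale, get_multi_stage_outputs,
# aggregate_results, get_final_preds).
def infer_single_scale(cfg, model, parser, transforms, image):
    base_size, center, scale = get_multi_scale_size(
        image, cfg.DATASET.INPUT_SIZE, 1.0, 1.0)
    with torch.no_grad():
        image_resized, center, scale = resize_align_multi_scale(
            image, cfg.DATASET.INPUT_SIZE, 1.0, 1.0)
        image_resized = transforms(image_resized).unsqueeze(0).cuda()
        outputs, heatmaps, tags = get_multi_stage_outputs(
            cfg, model, image_resized, cfg.TEST.FLIP_TEST,
            cfg.TEST.PROJECT2IMAGE, base_size)
        final_heatmaps, tags_list = aggregate_results(
            cfg, 1.0, None, [], heatmaps, tags)
        tags = torch.cat(tags_list, dim=4)
        grouped, scores = parser.parse(
            final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE)
        return get_final_preds(
            grouped, center, scale,
            [final_heatmaps.size(3), final_heatmaps.size(2)])
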
Example #12
def main_worker(gpus, ngpus_per_node, args, final_output_dir, tb_log_dir):
    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    #os.environ['CUDA_VISIBLE_DEVICES']=gpus

    # Parallel setting
    print("Use GPU: {} for training".format(gpus))

    update_config(cfg, args)

    #test(cfg, args)

    # logger setting
    logger, _ = setup_logger(final_output_dir, args.rank, 'train')

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    # model initialization
    model = eval(cfg.MODEL.NAME + '.get_pose_net')(cfg, is_train=True)

    # load pretrained model before DDP initialization
    checkpoint_file = os.path.join(final_output_dir, 'model_best.pth.tar')

    if cfg.AUTO_RESUME:
        if os.path.exists(checkpoint_file):
            checkpoint = torch.load(checkpoint_file, map_location='cpu')
            state_dict = checkpoint['state_dict']

            for key in list(state_dict.keys()):
                new_key = key.replace("module.", "")
                state_dict[new_key] = state_dict.pop(key)
            model.load_state_dict(state_dict)
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_file, checkpoint['epoch']))

    elif cfg.MODEL.HRNET_PRETRAINED:
        logger.info("=> loading a pretrained model '{}'".format(
            cfg.MODEL.PRETRAINED))
        checkpoint = torch.load(cfg.MODEL.HRNET_PRETRAINED, map_location='cpu')

        state_dict = checkpoint['state_dict']
        for key in list(state_dict.keys()):
            new_key = key.replace("module.", "")
            state_dict[new_key] = state_dict.pop(key)

        model.load_state_dict(state_dict)

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # copy configuration file
    config_dir = args.cfg
    shutil.copy2(os.path.join(args.cfg), final_output_dir)

    # calculate GFLOPS
    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[0]))

    logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    #ops, params = get_model_complexity_info(
    #    model, (3, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[0]),
    #    as_strings=True, print_per_layer_stat=True, verbose=True)
    # FP16 SETTING
    if cfg.FP16.ENABLED:
        assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."

    if cfg.FP16.STATIC_LOSS_SCALE != 1.0:
        if not cfg.FP16.ENABLED:
            print(
                "Warning:  if --fp16 is not used, static_loss_scale will be ignored."
            )

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.MODEL.SYNC_BN and not cfg.DISTRIBUTED:
        print(
            'Warning: Sync BatchNorm is only supported in distributed training.'
        )

    # Distributed Computing
    master = True
    if cfg.DISTRIBUTED:  # This block is not available
        args.local_rank += int(gpus[0])
        print('This process is using GPU', args.local_rank)
        device = args.local_rank
        master = device == int(gpus[0])
        dist.init_process_group(backend='nccl')
        if cfg.MODEL.SYNC_BN:
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if gpus is not None:
            torch.cuda.set_device(device)
            model.cuda(device)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            # workers = int(workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[device],
                output_device=device,
                find_unused_parameters=True)
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    else:  # single-process DataParallel path (the branch actually used)
        gpu_ids = eval('[' + gpus + ']')
        device = gpu_ids[0]
        print('This process is using GPU', str(device))
        model = torch.nn.DataParallel(model, gpu_ids).cuda(device)

    # Prepare loss functions
    criterion = {}
    if cfg.LOSS.WITH_HEATMAP_LOSS:
        criterion['heatmap_loss'] = HeatmapLoss().cuda()
    if cfg.LOSS.WITH_POSE2D_LOSS:
        criterion['pose2d_loss'] = JointsMSELoss().cuda()
    if cfg.LOSS.WITH_BONE_LOSS:
        criterion['bone_loss'] = BoneLengthLoss().cuda()
    if cfg.LOSS.WITH_JOINTANGLE_LOSS:
        criterion['jointangle_loss'] = JointAngleLoss().cuda()

    best_perf = 1e9
    best_model = False
    last_epoch = -1

    # the optimizer must be initialized after the model initialization
    optimizer = get_optimizer(cfg, model)

    if cfg.FP16.ENABLED:
        optimizer = FP16_Optimizer(
            optimizer,
            static_loss_scale=cfg.FP16.STATIC_LOSS_SCALE,
            dynamic_loss_scale=cfg.FP16.DYNAMIC_LOSS_SCALE,
            verbose=False)

    begin_epoch = cfg.TRAIN.BEGIN_EPOCH

    if not cfg.AUTO_RESUME and cfg.MODEL.HRNET_PRETRAINED:
        optimizer.load_state_dict(checkpoint['optimizer'])

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['loss']
        optimizer.load_state_dict(checkpoint['optimizer'])

        if 'train_global_steps' in checkpoint.keys() and \
        'valid_global_steps' in checkpoint.keys():
            writer_dict['train_global_steps'] = checkpoint[
                'train_global_steps']
            writer_dict['valid_global_steps'] = checkpoint[
                'valid_global_steps']

    if cfg.FP16.ENABLED:
        logger.info("=> Using FP16 mode")
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer.optimizer,
            cfg.TRAIN.LR_STEP,
            cfg.TRAIN.LR_FACTOR,
            last_epoch=begin_epoch)
    elif cfg.TRAIN.LR_SCHEDULE == 'warmup':
        from utils.utils import get_linear_schedule_with_warmup
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=cfg.TRAIN.WARMUP_EPOCHS,
            num_training_steps=cfg.TRAIN.END_EPOCH - cfg.TRAIN.BEGIN_EPOCH,
            last_epoch=begin_epoch)
    elif cfg.TRAIN.LR_SCHEDULE == 'multi_step':
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            cfg.TRAIN.LR_STEP,
            cfg.TRAIN.LR_FACTOR,
            last_epoch=begin_epoch)
    else:
        print('Unknown learning rate schedule!')
        exit()

    # Data loading code
    train_loader_dict = make_dataloader(cfg,
                                        is_train=True,
                                        distributed=cfg.DISTRIBUTED)
    valid_loader_dict = make_dataloader(cfg,
                                        is_train=False,
                                        distributed=cfg.DISTRIBUTED)

    for i, (dataset_name,
            train_loader) in enumerate(train_loader_dict.items()):
        logger.info(
            'Training Loader {}/{}:\n'.format(i + 1, len(train_loader_dict)) +
            str(train_loader.dataset))
    for i, (dataset_name,
            valid_loader) in enumerate(valid_loader_dict.items()):
        logger.info('Validation Loader {}/{}:\n'.format(
            i + 1, len(valid_loader_dict)) + str(valid_loader.dataset))

    #writer_dict['writer'].add_graph(model, (dump_input, ))
    """
    Start training
    """
    start_time = time.time()

    with torch.autograd.set_detect_anomaly(True):
        for epoch in range(begin_epoch + 1, cfg.TRAIN.END_EPOCH + 1):
            epoch_start_time = time.time()
            # shuffle datasets with the same random seed
            if cfg.DISTRIBUTED:
                for data_loader in train_loader_dict.values():
                    data_loader.sampler.set_epoch(epoch)
            # train for one epoch
            # get_last_lr() returns a list
            logger.info('Start training [{}/{}] lr: {:.4e}'.format(
                epoch, cfg.TRAIN.END_EPOCH - cfg.TRAIN.BEGIN_EPOCH,
                lr_scheduler.get_last_lr()[0]))
            train(cfg,
                  args,
                  master,
                  train_loader_dict,
                  model,
                  criterion,
                  optimizer,
                  epoch,
                  final_output_dir,
                  tb_log_dir,
                  writer_dict,
                  logger,
                  fp16=cfg.FP16.ENABLED,
                  device=device)

            # In PyTorch 1.1.0 and later, you should call `lr_scheduler.step()` after `optimizer.step()`.
            lr_scheduler.step()

            # evaluate on validation set

            if not cfg.WITHOUT_EVAL:
                logger.info('Start evaluating [{}/{}]'.format(
                    epoch, cfg.TRAIN.END_EPOCH - 1))
                with torch.no_grad():
                    recorder = validate(cfg,
                                        args,
                                        master,
                                        valid_loader_dict,
                                        model,
                                        criterion,
                                        final_output_dir,
                                        tb_log_dir,
                                        writer_dict,
                                        logger,
                                        device=device)

                val_total_loss = recorder.avg_total_loss

                best_model = False
                if val_total_loss < best_perf:
                    logger.info(
                        'This epoch yielded a better model with total loss {:.4f} < {:.4f}.'
                        .format(val_total_loss, best_perf))
                    best_perf = val_total_loss
                    best_model = True

            else:
                val_total_loss = 0
                best_model = True

            if master:
                logger.info(
                    '=> saving checkpoint to {}'.format(final_output_dir))
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'model': cfg.EXP_NAME + '.' + cfg.MODEL.NAME,
                        'state_dict': model.state_dict(),
                        'loss': val_total_loss,
                        'optimizer': optimizer.state_dict(),
                        'train_global_steps':
                        writer_dict['train_global_steps'],
                        'valid_global_steps': writer_dict['valid_global_steps']
                    }, best_model, final_output_dir)

            print('\nEpoch {} spent {:.2f} hours\n'.format(
                epoch, (time.time() - epoch_start_time) / 3600))

            #if epoch == 3:break
    if master:
        final_model_state_file = os.path.join(
            final_output_dir, 'final_state{}.pth.tar'.format(gpus))
        logger.info(
            '=> saving final model state to {}'.format(final_model_state_file))
        torch.save(model.state_dict(), final_model_state_file)
        writer_dict['writer'].close()

        print(
            '\n[Training Accomplished] {} epochs spent {:.2f} hours\n'.format(
                cfg.TRAIN.END_EPOCH - begin_epoch + 1,
                (time.time() - start_time) / 3600))
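
# --- Illustrative sketch (assumed behaviour of get_linear_schedule_with_warmup) ---
# The 'warmup' branch above presumably scales the learning rate linearly from 0
# over the warmup epochs and then decays it linearly to 0 by the final epoch.
# A standard LambdaLR formulation of that schedule looks like this; the helper
# name and signature here are assumptions, not the repository's utils.utils code.
import torch


def linear_warmup_schedule(optimizer, num_warmup_steps, num_training_steps,
                           last_epoch=-1):
    def lr_lambda(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - step) /
                   float(max(1, num_training_steps - num_warmup_steps)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)
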
Example #13
def main():
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)
    pose_dir = prepare_output_dirs(args.outputDir)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=False)

    dump_input = torch.rand(
        (1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE))
    logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
    else:
        model_state_file = os.path.join(final_output_dir, 'model_best.pth.tar')
        logger.info('=> loading model from {}'.format(model_state_file))
        # model.load_state_dict(torch.load(model_state_file))
        pretrained_model_state = torch.load(model_state_file)

        for name, param in model.state_dict().items():
            model.state_dict()[name].copy_(pretrained_model_state['1.' + name])

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    # data_loader, test_dataset = make_test_dataloader(cfg)
    if cfg.MODEL.NAME == 'pose_hourglass':
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])
    else:
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
        ])

    parser = HeatmapParser(cfg)
    # Load the input video
    vidcap = cv2.VideoCapture(args.videoFile)
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    if fps < args.inferenceFps:
        print('desired inference fps is ' + str(args.inferenceFps) +
              ' but video fps is ' + str(fps))
        exit()
    skip_frame_cnt = round(fps / args.inferenceFps)
    frame_width = int(vidcap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    outcap = cv2.VideoWriter(
        '{}/{}_pose.avi'.format(
            args.outputDir,
            os.path.splitext(os.path.basename(args.videoFile))[0]),
        cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), int(skip_frame_cnt),
        (frame_width, frame_height))

    count = 0
    while vidcap.isOpened():
        total_now = time.time()
        ret, image_bgr = vidcap.read()
        count += 1

        if not ret:
            break

        if count % skip_frame_cnt != 0:
            continue

        image_debug = image_bgr.copy()
        now = time.time()
        image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        # image = image_rgb.cpu().numpy()
        # size at scale 1.0
        base_size, center, scale = get_multi_scale_size(
            image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR))
        with torch.no_grad():
            final_heatmaps = None
            tags_list = []
            for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR,
                                           reverse=True)):
                input_size = cfg.DATASET.INPUT_SIZE
                image_resized, center, scale = resize_align_multi_scale(
                    image, input_size, s, min(cfg.TEST.SCALE_FACTOR))
                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                outputs, heatmaps, tags = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                    cfg.TEST.PROJECT2IMAGE, base_size)

                final_heatmaps, tags_list = aggregate_results(
                    cfg, s, final_heatmaps, tags_list, heatmaps, tags)

            final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
            tags = torch.cat(tags_list, dim=4)
            grouped, scores = parser.parse(final_heatmaps, tags,
                                           cfg.TEST.ADJUST, cfg.TEST.REFINE)

            final_results = get_final_preds(
                grouped, center, scale,
                [final_heatmaps.size(3),
                 final_heatmaps.size(2)])
        for person_joints in final_results:
            for joint in person_joints:
                x, y = int(joint[0]), int(joint[1])
                cv2.circle(image_debug, (x, y), 4, (255, 0, 0), 2)
        then = time.time()
        print("Find person pose in: {} sec".format(then - now))

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
        img_file = os.path.join(pose_dir, 'pose_{:08d}.jpg'.format(count))
        cv2.imwrite(img_file, image_debug)
        outcap.write(image_debug)

    vidcap.release()
    outcap.release()
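
# --- Illustrative sketch: the frame-skipping logic above (count % skip_frame_cnt)
# can be factored into a small generator that yields every k-th decoded frame.
# iter_frames is a hypothetical helper, not part of the repository.
import cv2


def iter_frames(video_file, skip_frame_cnt):
    cap = cv2.VideoCapture(video_file)
    count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        count += 1
        if count % skip_frame_cnt == 0:
            yield count, frame      # frame index and BGR image
    cap.release()
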
Example #14
def worker(gpu_id, dataset, indices, cfg, logger, final_output_dir,
           pred_queue):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)

    model = eval("models." + cfg.MODEL.NAME + ".get_pose_net")(cfg,
                                                               is_train=False)

    dump_input = torch.rand(
        (1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE))

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.TEST.MODEL_FILE:
        logger.info("=> loading model from {}".format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
    else:
        model_state_file = os.path.join(final_output_dir, "model_best.pth.tar")
        logger.info("=> loading model from {}".format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    sub_dataset = torch.utils.data.Subset(dataset, indices)
    data_loader = torch.utils.data.DataLoader(sub_dataset,
                                              sampler=None,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=0,
                                              pin_memory=False)

    if cfg.MODEL.NAME == "pose_hourglass":
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])
    else:
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225]),
        ])

    parser = HeatmapParser(cfg)
    all_preds = []
    all_scores = []

    pbar = tqdm(total=len(sub_dataset)) if cfg.TEST.LOG_PROGRESS else None
    for i, (images, annos) in enumerate(data_loader):
        assert 1 == images.size(0), 'Test batch size should be 1'

        image = images[0].cpu().numpy()
        # size at scale 1.0
        base_size, center, scale = get_multi_scale_size(
            image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR))

        with torch.no_grad():
            final_heatmaps = None
            tags_list = []
            for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR,
                                           reverse=True)):
                input_size = cfg.DATASET.INPUT_SIZE
                image_resized, center, scale = resize_align_multi_scale(
                    image, input_size, s, min(cfg.TEST.SCALE_FACTOR))
                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                outputs, heatmaps, tags = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                    cfg.TEST.PROJECT2IMAGE, base_size)

                final_heatmaps, tags_list = aggregate_results(
                    cfg, s, final_heatmaps, tags_list, heatmaps, tags)

            final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
            tags = torch.cat(tags_list, dim=4)

            visual = False
            if visual:
                visual_heatmap = torch.max(final_heatmaps[0],
                                           dim=0,
                                           keepdim=True)[0]
                visual_heatmap = (visual_heatmap.cpu().numpy().repeat(
                    3, 0).transpose(1, 2, 0))

                mean = [0.485, 0.456, 0.406]
                std = [0.229, 0.224, 0.225]
                visual_img = (image_resized[0].cpu().numpy().transpose(
                    1, 2, 0).astype(np.float32))
                visual_img = visual_img[:, :, ::-1] * np.array(std).reshape(
                    1, 1, 3) + np.array(mean).reshape(1, 1, 3)
                visual_img = visual_img * 255
                test_data = cv2.addWeighted(
                    visual_img.astype(np.float32),
                    0.0,
                    visual_heatmap.astype(np.float32) * 255,
                    1.0,
                    0,
                )
                cv2.imwrite("test_data/{}.jpg".format(i), test_data)

            grouped, scores = parser.parse(final_heatmaps, tags,
                                           cfg.TEST.ADJUST, cfg.TEST.REFINE)

            final_results = get_final_preds(
                grouped, center, scale,
                [final_heatmaps.size(3),
                 final_heatmaps.size(2)])

        if cfg.TEST.LOG_PROGRESS:
            pbar.update()

        data_idx = indices[i]
        img_id = dataset.ids[data_idx]
        file_name = dataset.coco.loadImgs(img_id)[0]["file_name"]

        for idx in range(len(final_results)):
            all_preds.append({
                "keypoints":
                final_results[idx][:, :3].reshape(-1, ).astype(float).tolist(),
                "image_id":
                int(file_name[-16:-4]),
                "score":
                float(scores[idx]),
                "category_id":
                1
            })

    if cfg.TEST.LOG_PROGRESS:
        pbar.close()
    pred_queue.put_nowait(all_preds)
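
# --- Illustrative sketch: a driver that shards the dataset across GPUs, runs
# worker() above in one process per GPU, and merges the COCO-format predictions
# from pred_queue. The splitting strategy, start method and function name are
# assumptions; results are read from the queue before join() to avoid blocking.
import numpy as np
import torch.multiprocessing as mp


def run_workers(dataset, cfg, logger, final_output_dir, gpu_ids):
    ctx = mp.get_context('spawn')
    pred_queue = ctx.Queue()
    index_splits = np.array_split(np.arange(len(dataset)), len(gpu_ids))
    processes = []
    for gpu_id, indices in zip(gpu_ids, index_splits):
        p = ctx.Process(target=worker,
                        args=(gpu_id, dataset, indices.tolist(), cfg, logger,
                              final_output_dir, pred_queue))
        p.start()
        processes.append(p)
    all_preds = []
    for _ in processes:
        all_preds.extend(pred_queue.get())   # one list of predictions per worker
    for p in processes:
        p.join()
    return all_preds
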
Example #15
def main():
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'test'
    )

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg, is_train=False)

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.TEST.MODEL_FILE:
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
    else:
        model_state_file = os.path.join(final_output_dir, 'model_best.pth.tar')
        model.load_state_dict(torch.load(model_state_file))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    if cfg.MODEL.NAME == 'pose_hourglass':
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )
    else:
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
        )

    HMparser = HeatmapParser(cfg)  # ans, scores

    res_folder = os.path.join(args.outdir, 'results')
    if not os.path.exists(res_folder):
        os.makedirs(res_folder)
    video_name = args.video_path.split('/')[-1].split('.')[0]
    res_file = os.path.join(res_folder, '{}.json'.format(video_name))

    # read frames in video
    stream = cv2.VideoCapture(args.video_path)
    assert stream.isOpened(), 'Cannot capture source'


    # fourcc = int(stream.get(cv2.CAP_PROP_FOURCC))
    fps = stream.get(cv2.CAP_PROP_FPS)
    frameSize = (int(stream.get(cv2.CAP_PROP_FRAME_WIDTH)),
                      int(stream.get(cv2.CAP_PROP_FRAME_HEIGHT)))

    video_dir = os.path.join(args.outdir, 'video', args.data)
    if not os.path.exists(video_dir):
        os.makedirs(video_dir)

    image_dir=os.path.join(args.outdir, 'images', args.data)
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    if args.video_format == 'mp4':
        fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
        video_path = os.path.join(video_dir, '{}.mp4'.format(video_name))
    else:
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        video_path = os.path.join(video_dir, '{}.avi'.format(video_name))

    if args.save_video:
        out = cv2.VideoWriter(video_path, fourcc, fps, frameSize)

    num = 0
    annolist = []
    while (True):
        ret, image = stream.read()
        print("num:", num)

        if ret is False:
            break

        all_preds = []
        all_scores = []

        # size at scale 1.0
        base_size, center, scale = get_multi_scale_size(
            image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR)
        )

        with torch.no_grad():
            final_heatmaps = None
            tags_list = []
            for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR, reverse=True)):
                input_size = cfg.DATASET.INPUT_SIZE
                image_resized, center, scale = resize_align_multi_scale(
                    image, input_size, s, min(cfg.TEST.SCALE_FACTOR)
                )
                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                outputs, heatmaps, tags = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                    cfg.TEST.PROJECT2IMAGE, base_size
                )

                final_heatmaps, tags_list = aggregate_results(
                    cfg, s, final_heatmaps, tags_list, heatmaps, tags
                )

            final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
            tags = torch.cat(tags_list, dim=4)
            grouped, scores = HMparser.parse(
                final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE
            )

            final_results = get_final_preds(  # joints for all persons in an image
                grouped, center, scale,
                [final_heatmaps.size(3), final_heatmaps.size(2)]
            )

        image=draw_image(image, final_results, dataset=args.data)
        all_preds.append(final_results)
        all_scores.append(scores)

        img_id = num
        num += 1
        file_name = '{}.jpg'.format(str(img_id).zfill(6))
        annorect = person_result(all_preds, scores, img_id)

        annolist.append({
            'annorect': annorect,
            'ignore_regions': [],
            'image': [{'name': file_name}]
        })
        # print(annorect)

        if args.save_video:
            out.write(image)

        if args.save_img:
            img_path = os.path.join(image_dir, file_name)
            cv2.imwrite(img_path, image)

    final_results = {'annolist': annolist}
    with open(res_file, 'w') as f:
        json.dump(final_results, f)
    print('=> create test json finished!')

    # print('=> finished! you can check the output video on {}'.format(save_path))
    stream.release()
    # out.release()
    cv2.destroyAllWindows()
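
# --- Illustrative sketch: reading the JSON written above. Each entry of
# 'annolist' carries the frame file name and the per-person 'annorect' records
# produced by person_result (whose exact fields are repository-specific). The
# default path is a placeholder.
import json


def summarize_annolist(res_file='results/video.json'):
    with open(res_file) as f:
        results = json.load(f)
    for frame in results['annolist']:
        name = frame['image'][0]['name']
        print(name, len(frame['annorect']), 'person(s)')
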
Example #16
def main():
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=False)

    dump_input = torch.rand(
        (1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE))
    logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
    else:
        model_state_file = os.path.join(final_output_dir, 'model_best.pth.tar')
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()
    if cfg.MODEL.NAME == 'pose_hourglass':
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])
    else:
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
        ])
    transforms_pre = torchvision.transforms.Compose([
        ToNumpy(),
    ])
    # iterate over all datasets
    datasets_root_path = "/media/jld/DATOS_JLD/datasets"
    datasets = ["cityscapes", "kitti", "tsinghua"]
    # the test sets of cityscapes and kitti do not have ground truth --> no processing required
    datasplits = [["train", "val"], ["train"], ["train", "val", "test"]]
    keypoints_output_root_path = "/media/jld/DATOS_JLD/git-repos/paper-revista-keypoints/results"
    model_name = osp.basename(
        cfg.TEST.MODEL_FILE).split('.')[0]  # Model name + configuration
    for dsid, dataset in enumerate(datasets):
        dataset_root_path = osp.join(datasets_root_path, dataset)
        output_root_path = osp.join(keypoints_output_root_path, dataset)
        for datasplit in datasplits[dsid]:
            loggur.info(f"Processing split {datasplit} of {dataset}")
            input_img_dir = osp.join(dataset_root_path, datasplit)
            output_kps_json_dir = osp.join(output_root_path, datasplit,
                                           model_name)
            loggur.info(f"Input image dir: {input_img_dir}")
            loggur.info(f"Output pose JSON dir: {output_kps_json_dir}")
            # test_dataset = torchvision.datasets.ImageFolder("/media/jld/DATOS_JLD/git-repos/paper-revista-keypoints/test_images/", transform=transforms_pre)
            test_dataset = dsjld.BaseDataset(input_img_dir,
                                             output_kps_json_dir,
                                             transform=transforms_pre)
            test_dataset.generate_io_samples_pairs()
            # Establish keypoint score weights (as in OpenPifPaf: https://github.com/vita-epfl/openpifpaf/blob/master/openpifpaf/decoder/annotation.py#L44)
            n_keypoints = 17
            kps_score_weights = numpy.ones((n_keypoints, ))
            kps_score_weights[:3] = 3.0
            # Normalize weights to sum 1
            kps_score_weights /= numpy.sum(kps_score_weights)
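            # Note: in the scoring below the keypoint confidences are sorted in
            # descending order before weighting, so the 3x weights emphasise the
            # three most confident keypoints rather than any fixed joints; the
            # pose score is sum(w_i * sorted_scores[i]) with weights summing to 1.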
            data_loader = torch.utils.data.DataLoader(test_dataset,
                                                      batch_size=1,
                                                      shuffle=False,
                                                      num_workers=0,
                                                      pin_memory=False)
            parser = HeatmapParser(cfg)
            all_preds = []
            all_scores = []

            pbar = tqdm(total=len(test_dataset))
            for i, (img, imgidx) in enumerate(data_loader):
                assert 1 == img.size(0), 'Test batch size should be 1'

                img = img[0].cpu().numpy()
                # size at scale 1.0
                base_size, center, scale = get_multi_scale_size(
                    img, cfg.DATASET.INPUT_SIZE, 1.0,
                    min(cfg.TEST.SCALE_FACTOR))

                with torch.no_grad():
                    final_heatmaps = None
                    tags_list = []
                    for idx, s in enumerate(
                            sorted(cfg.TEST.SCALE_FACTOR, reverse=True)):
                        input_size = cfg.DATASET.INPUT_SIZE
                        image_resized, center, scale = resize_align_multi_scale(
                            img, input_size, s, min(cfg.TEST.SCALE_FACTOR))
                        image_resized = transforms(image_resized)
                        image_resized = image_resized.unsqueeze(0).cuda()

                        outputs, heatmaps, tags = get_multi_stage_outputs(
                            cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                            cfg.TEST.PROJECT2IMAGE, base_size)

                        final_heatmaps, tags_list = aggregate_results(
                            cfg, s, final_heatmaps, tags_list, heatmaps, tags)

                    final_heatmaps = final_heatmaps / float(
                        len(cfg.TEST.SCALE_FACTOR))
                    tags = torch.cat(tags_list, dim=4)
                    grouped, scores = parser.parse(final_heatmaps, tags,
                                                   cfg.TEST.ADJUST,
                                                   cfg.TEST.REFINE)

                    final_results = get_final_preds(
                        grouped, center, scale,
                        [final_heatmaps.size(3),
                         final_heatmaps.size(2)])

                pbar.update()
                # Save all keypoints in a JSON dict
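                # final_results is assumed to hold one (num_joints x >=3) array per
                # detected person: columns 0-1 are x/y mapped back to the original
                # image and column 2 is the per-keypoint confidence used below.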
                final_json_results = []
                for kps in final_results:
                    kpsdict = {}
                    x = kps[:, 0]
                    y = kps[:, 1]
                    kps_scores = kps[:, 2]
                    kpsdict['keypoints'] = kps[:, 0:3].tolist()
                    # bounding box from the min/max of the non-zero keypoint coordinates
                    xmin = numpy.float64(numpy.min(x[numpy.nonzero(x)]))
                    xmax = numpy.float64(numpy.max(x))
                    width = numpy.float64(xmax - xmin)
                    ymin = numpy.float64(numpy.min(y[numpy.nonzero(y)]))
                    ymax = numpy.float64(numpy.max(y))
                    height = numpy.float64(ymax - ymin)
                    kpsdict['bbox'] = [xmin, ymin, width, height]
                    # Calculate pose score as a weighted mean of keypoints scores
                    kpsdict['score'] = numpy.float64(
                        numpy.sum(kps_score_weights *
                                  numpy.sort(kps_scores)[::-1]))
                    final_json_results.append(kpsdict)

                with open(test_dataset.output_json_files_list[imgidx],
                          "w") as f:
                    json.dump(final_json_results, f)

                all_preds.append(final_results)
                all_scores.append(scores)

            pbar.close()
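dsjld.BaseDataset and ToNumpy are project-specific helpers that are not shown in this listing. A minimal stand-in with the interface the loop above relies on could look like the following sketch (an assumption about the dataset, not the actual implementation): __getitem__ yields an image plus its index, and output_json_files_list holds one output JSON path per input image.

import glob
import os
import os.path as osp

from PIL import Image
from torch.utils.data import Dataset


class BaseDatasetSketch(Dataset):
    """Hypothetical stand-in for dsjld.BaseDataset used above."""

    def __init__(self, input_img_dir, output_json_dir, transform=None):
        self.img_paths = sorted(
            glob.glob(osp.join(input_img_dir, '**', '*.png'), recursive=True))
        self.output_json_dir = output_json_dir
        self.output_json_files_list = []
        self.transform = transform

    def generate_io_samples_pairs(self):
        # Pair every input image with an output JSON path (assumed behaviour).
        os.makedirs(self.output_json_dir, exist_ok=True)
        self.output_json_files_list = [
            osp.join(self.output_json_dir,
                     osp.splitext(osp.basename(p))[0] + '.json')
            for p in self.img_paths
        ]

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = Image.open(self.img_paths[idx]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)  # e.g. ToNumpy() from the listing above
        return img, idx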
Example #17
def main():
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=False)

    dump_input = torch.rand(
        (1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE))
    logger.info(get_model_summary(model, dump_input, verbose=cfg.VERBOSE))

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
    else:
        model_state_file = os.path.join(final_output_dir, 'model_best.pth.tar')
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    data_loader, test_dataset = make_test_dataloader(cfg)

    if cfg.MODEL.NAME == 'pose_hourglass':
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])
    else:
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
        ])

    parser = HeatmapParser(cfg)
    all_preds = []
    all_scores = []

    pbar = tqdm(total=len(test_dataset))
    for i, (images, annos) in enumerate(data_loader):
        assert 1 == images.size(0), 'Test batch size should be 1'

        image = images[0].cpu().numpy()
        # size at scale 1.0
        base_size, center, scale = get_multi_scale_size(
            image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR))
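        # Multi-scale testing: every scale in cfg.TEST.SCALE_FACTOR is pushed
        # through the network, the heatmaps are projected back to base_size and
        # averaged, and the associative-embedding tags from all scales are
        # concatenated so the parser can group joints into person instances.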

        with torch.no_grad():
            final_heatmaps = None
            tags_list = []
            for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR,
                                           reverse=True)):
                input_size = cfg.DATASET.INPUT_SIZE
                image_resized, center, scale = resize_align_multi_scale(
                    image, input_size, s, min(cfg.TEST.SCALE_FACTOR))
                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                outputs, heatmaps, tags = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                    cfg.TEST.PROJECT2IMAGE, base_size)

                final_heatmaps, tags_list = aggregate_results(
                    cfg, s, final_heatmaps, tags_list, heatmaps, tags)

            final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
            tags = torch.cat(tags_list, dim=4)
            grouped, scores = parser.parse(final_heatmaps, tags,
                                           cfg.TEST.ADJUST, cfg.TEST.REFINE)

            final_results = get_final_preds(
                grouped, center, scale,
                [final_heatmaps.size(3),
                 final_heatmaps.size(2)])
            if cfg.RESCORE.USE:
                try:
                    scores = rescore_valid(cfg, final_results, scores)
                except Exception:
                    print('rescore_valid failed on image {}; keeping original scores'.format(i))
        pbar.update()

        if i % cfg.PRINT_FREQ == 0:
            prefix = '{}_{}'.format(
                os.path.join(final_output_dir, 'result_valid'), i)
            # logger.info('=> write {}'.format(prefix))
            save_valid_image(image,
                             final_results,
                             '{}.jpg'.format(prefix),
                             dataset=test_dataset.name)
            # for scale_idx in range(len(outputs)):
            #     prefix_scale = prefix + '_output_{}'.format(
            #         # cfg.DATASET.OUTPUT_SIZE[scale_idx]
            #         scale_idx
            #     )
            #     save_debug_images(
            #         cfg, images, None, None,
            #         outputs[scale_idx], prefix_scale
            #     )
        all_preds.append(final_results)
        all_scores.append(scores)

    pbar.close()

    name_values, _ = test_dataset.evaluate(cfg, all_preds, all_scores,
                                           final_output_dir)

    if isinstance(name_values, list):
        for name_value in name_values:
            _print_name_value(logger, name_value, cfg.MODEL.NAME)
    else:
        _print_name_value(logger, name_values, cfg.MODEL.NAME)
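_print_name_value comes from the shared validation utilities and is not reproduced in this listing. A minimal stand-in consistent with how it is called above (an assumption, not the library code) would simply log one table row of metric names and values, e.g. the COCO-style AP numbers returned by test_dataset.evaluate():

def _print_name_value_sketch(logger, name_value, full_arch_name):
    # Log the evaluation metrics as a simple one-row table.
    names = list(name_value.keys())
    values = list(name_value.values())
    logger.info('| Arch ' + ''.join('| {} '.format(n) for n in names) + '|')
    logger.info('|---' * (len(names) + 1) + '|')
    logger.info('| {} '.format(full_arch_name) +
                ''.join('| {:.3f} '.format(v) for v in values) + '|')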
Example #18
def worker(gpu_id, img_list, cfg, logger, final_output_dir, save_dir, pred_queue):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)

    model = eval("models." + cfg.MODEL.NAME + ".get_pose_net")(cfg, is_train=False)

    dump_input = torch.rand((1, 3, cfg.DATASET.INPUT_SIZE, cfg.DATASET.INPUT_SIZE))

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.TEST.MODEL_FILE:
        logger.info("=> loading model from {}".format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
    else:
        model_state_file = os.path.join(final_output_dir, "model_best.pth.tar")
        logger.info("=> loading model from {}".format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))
        
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    if cfg.MODEL.NAME == "pose_hourglass":
        transforms = torchvision.transforms.Compose(
            [torchvision.transforms.ToTensor(),]
        )
    else:
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    parser = HeatmapParser(cfg)
    all_preds = []
    all_scores = []

    pbar = tqdm(total=len(img_list)) if cfg.TEST.LOG_PROGRESS else None
    for i, img_path in enumerate(img_list):

        image_name = img_path.split('/')[-1].split('.')[0]
        image = cv2.imread(
            img_path, 
            cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
        )

        # size at scale 1.0
        base_size, center, scale = get_multi_scale_size(
            image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR)
        )

        with torch.no_grad():
            final_heatmaps = None
            tags_list = []
            for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR, reverse=True)):
                input_size = cfg.DATASET.INPUT_SIZE
                image_resized, center, scale = resize_align_multi_scale(
                    image, input_size, s, min(cfg.TEST.SCALE_FACTOR)
                )
                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                outputs, heatmaps, tags = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                    cfg.TEST.PROJECT2IMAGE, base_size
                )

                final_heatmaps, tags_list = aggregate_results(
                    cfg, s, final_heatmaps, tags_list, heatmaps, tags
                )

            final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
            tags = torch.cat(tags_list, dim=4)

            grouped, scores = parser.parse(
                final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE
            )

            final_results = get_final_preds(
                grouped, center, scale,
                [final_heatmaps.size(3), final_heatmaps.size(2)]
            )

        if cfg.TEST.LOG_PROGRESS:
            pbar.update()

        for idx in range(len(final_results)):
            all_preds.append({
                "keypoints": final_results[idx][:,:3].reshape(-1,).astype(np.float).tolist(),
                "image_name": image_name,
                "score": float(scores[idx]),
                "category_id": 1
            })
        
        skeleton_map = draw_skeleton(image, np.array(final_results))
        cv2.imwrite(
            os.path.join(save_dir, "{}.jpg".format(image_name)),
            skeleton_map
            )

    if cfg.TEST.LOG_PROGRESS:
        pbar.close()
    pred_queue.put_nowait(all_preds)
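The driver that spawns worker() is not part of this listing. A minimal sketch of one (an assumption about how the pieces fit together, not the original code) splits the image list across GPUs, starts one process per GPU, and merges the per-process predictions collected from pred_queue:

import json
import os

import torch.multiprocessing as mp


def run_workers(img_list, cfg, logger, final_output_dir, save_dir, gpu_ids):
    # CUDA in subprocesses requires the 'spawn' start method.
    ctx = mp.get_context('spawn')
    pred_queue = ctx.Queue()

    # Round-robin split of the image list, one chunk per GPU.
    chunks = [img_list[i::len(gpu_ids)] for i in range(len(gpu_ids))]
    procs = []
    for gpu_id, chunk in zip(gpu_ids, chunks):
        p = ctx.Process(target=worker,
                        args=(gpu_id, chunk, cfg, logger,
                              final_output_dir, save_dir, pred_queue))
        p.start()
        procs.append(p)

    # Each worker puts exactly one list of predictions on the queue.
    all_preds = []
    for _ in procs:
        all_preds.extend(pred_queue.get())
    for p in procs:
        p.join()

    with open(os.path.join(final_output_dir, 'keypoint_predictions.json'), 'w') as f:
        json.dump(all_preds, f)
    return all_preds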