def load_dataset() -> MocapDataset:
    """
    Load the dataset.
    Returns: dataset
    """
    print('Loading dataset...')
    dataset_path = 'data/data_3d_' + args.dataset + '.npz'
    if args.dataset == 'h36m':
        # Human3.6M 3D keypoint dataset
        from common.h36m_dataset import Human36mDataset
        dataset_ = Human36mDataset(dataset_path)
    elif args.dataset.startswith('humaneva'):
        # HumanEva 3D keypoint dataset
        from common.humaneva_dataset import HumanEvaDataset
        dataset_ = HumanEvaDataset(dataset_path)
    elif args.dataset.startswith('custom'):
        # The custom dataset contains only 2D keypoints, used to predict 3D keypoints
        from common.custom_dataset import CustomDataset
        dataset_ = CustomDataset('data/data_2d_' + args.dataset + '_' +
                                 args.keypoints + '.npz')
    else:
        raise KeyError('Invalid dataset')

    return dataset_
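
# A hedged usage sketch (illustrative only, not part of the original snippet):
# load_dataset() reads the module-level `args` namespace, so only the fields it
# uses (`dataset`, and `keypoints` for the custom branch) need to be provided.
if __name__ == '__main__':
    from argparse import Namespace
    args = Namespace(dataset='h36m', keypoints='gt')  # values here are assumptions
    dataset = load_dataset()
    print('Subjects:', dataset.subjects())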
Example #2
                positions /= 1000  # Meters instead of millimeters
                output[subject][canonical_name] = positions.astype('float32')

        print('Saving...')
        np.savez_compressed(output_filename, positions_3d=output)

        print('Done.')

    else:
        print('Please specify the dataset source')
        exit(0)

    # Create 2D pose file
    print('')
    print('Computing ground-truth 2D poses...')
    dataset = Human36mDataset(output_filename + '.npz')
    output_2d_poses = {}
    for subject in dataset.subjects():
        output_2d_poses[subject] = {}
        for action in dataset[subject].keys():
            anim = dataset[subject][action]

            positions_2d = []
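            # Project the world-space joints into each camera: rotate/translate into
            # camera coordinates, apply the intrinsics, then map the normalized
            # coordinates to pixel space.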
            for cam in anim['cameras']:
                pos_3d = world_to_camera(anim['positions'],
                                         R=cam['orientation'],
                                         t=cam['translation'])
                pos_2d = wrap(project_to_2d, True, pos_3d, cam['intrinsic'])
                pos_2d_pixel_space = image_coordinates(pos_2d,
                                                       w=cam['res_w'],
                                                       h=cam['res_h'])
Example #3
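                # Test-time flip augmentation: batch element 1 holds the horizontally
                # flipped input, so un-flip its x coordinates, swap left/right joints,
                # and average the two predictions.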
                predicted_3d_pos[1, :, :, 0] *= -1
                predicted_3d_pos[1, :, joints_left +
                                 joints_right] = predicted_3d_pos[1, :, joints_right + joints_left]
                predicted_3d_pos = torch.mean(
                    predicted_3d_pos, dim=0, keepdim=True)

            if return_predictions:
                return predicted_3d_pos.squeeze(0).cpu().numpy()


time0 = ckpt_time()
print('Loading 3D dataset...')
# Set this to your own data path
dataset_path = '../data/videopose/data_3d_' + \
    args.dataset + '.npz'  # dataset 'h36m'
dataset = Human36mDataset(dataset_path)  # '/path/to/data_3d_h36m.npz'

ckpt, time1 = ckpt_time(time0)
print('Loading the 3D dataset took {:.2f} seconds'.format(ckpt))

# Choose the keypoint metadata format according to the detector used; we use Detectron (COCO layout)
metadata = suggest_metadata('detectron_pt_coco')
print('Loading 2D detection keypoints...')

if args.input_npz:
    # if a keypoint npz file already exists
    npz = np.load(args.input_npz)
    keypoints = npz['kpts']
else:
    # create keypoints with AlphaPose
    from Alphapose.gene_npz import handle_video
Example #4



try:
    # Create checkpoint directory if it does not exist
    os.makedirs(args.checkpoint)
except OSError as e:
    if e.errno != errno.EEXIST:
        raise RuntimeError('Unable to create checkpoint directory:', args.checkpoint)
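
# Equivalent, more compact idiom on Python 3.2+: os.makedirs(args.checkpoint, exist_ok=True)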

print('Loading dataset...')
dataset_path = 'data/data_3d_' + args.dataset + '.npz'
if args.dataset == 'h36m':
    from common.h36m_dataset import Human36mDataset
    dataset = Human36mDataset(dataset_path)
elif args.dataset.startswith('humaneva'):
    from common.humaneva_dataset import HumanEvaDataset
    dataset = HumanEvaDataset(dataset_path)
elif args.dataset.startswith('custom'):
    from common.custom_dataset import CustomDataset
    dataset = CustomDataset('data/data_2d_' + args.dataset + '_' + args.keypoints + '.npz')
else:
    raise KeyError('Invalid dataset')

print('Preparing data...')
for subject in dataset.subjects():
    for action in dataset[subject].keys():
        anim = dataset[subject][action]
        
        if 'positions' in anim:
Example #5
def main(args):
    print('==> Using settings {}'.format(args))

    print('==> Loading dataset...')
    dataset_path = path.join('data', 'data_3d_' + args.dataset + '.npz')
    if args.dataset == 'h36m':
        from common.h36m_dataset import Human36mDataset, TRAIN_SUBJECTS, TEST_SUBJECTS
        dataset = Human36mDataset(dataset_path)
        subjects_train = TRAIN_SUBJECTS
        subjects_test = TEST_SUBJECTS
    else:
        raise KeyError('Invalid dataset')

    print('==> Preparing data...')
    dataset = read_3d_data(dataset)

    print('==> Loading 2D detections...')
    keypoints = create_2d_data(
        path.join('data',
                  'data_2d_' + args.dataset + '_' + args.keypoints + '.npz'),
        dataset)

    action_filter = None if args.actions == '*' else args.actions.split(',')
    if action_filter is not None:
        action_filter = list(map(lambda x: dataset.define_actions(x)[0],
                                 action_filter))
        print('==> Selected actions: {}'.format(action_filter))

    stride = args.downsample
    cudnn.benchmark = True
    device = torch.device("cuda")

    # Create model
    print("==> Creating model...")

    adj, adj_mutual = adj_mx_from_skeleton(
        dataset.skeleton())  # multi-person adjacency matrix
    model_pos = SemGCN(adj,
                       adj_mutual,
                       args.hid_dim,
                       num_layers=args.num_layers,
                       nodes_group=dataset.skeleton().joints_group()
                       if args.non_local else None).to(device)
    print("==> Total parameters: {:.2f}M".format(
        sum(p.numel() for p in model_pos.parameters()) / 1000000.0))

    criterion = nn.MSELoss(reduction='mean').to(device)
    optimizer = torch.optim.Adam(model_pos.parameters(), lr=args.lr)

    # Optionally resume from a checkpoint
    if args.resume or args.evaluate:
        ckpt_path = path.join(args.resume if args.resume else args.evaluate)

        if path.isfile(ckpt_path):
            print("=> Loading checkpoint '{}'".format(ckpt_path))
            ckpt = torch.load(ckpt_path)
            start_epoch = ckpt['epoch']
            error_best = ckpt['error']
            glob_step = ckpt['step']
            lr_now = ckpt['lr']
            model_pos.load_state_dict(ckpt['state_dict'])
            optimizer.load_state_dict(ckpt['optimizer'])
            print("=> Loaded checkpoint (Epoch: {} | Error: {})".format(
                start_epoch, error_best))

            if args.resume:
                ckpt_dir_path = path.dirname(ckpt_path)
                logger = Logger(path.join(ckpt_dir_path, 'log.txt'),
                                resume=True)
        else:
            raise RuntimeError(
                "=> No checkpoint found at '{}'".format(ckpt_path))
    else:
        start_epoch = 0
        error_best = None
        glob_step = 0
        lr_now = args.lr
        ckpt_dir_path = path.join(
            args.checkpoint,
            datetime.datetime.now().isoformat() +
            "_l_%04d_hid_%04d_e_%04d_non_local_%d" %
            (args.num_layers, args.hid_dim, args.epochs, args.non_local))

        if not path.exists(ckpt_dir_path):
            os.makedirs(ckpt_dir_path)
            print('=> Making checkpoint dir: {}'.format(ckpt_dir_path))

        logger = Logger(os.path.join(ckpt_dir_path, 'log.txt'))
        logger.set_names(
            ['epoch', 'lr', 'loss_train', 'error_eval_p1', 'error_eval_p2'])

    if args.evaluate:
        print('==> Evaluating...')

        if action_filter is None:
            action_filter = dataset.define_actions()

        errors_p1 = np.zeros(len(action_filter))
        errors_p2 = np.zeros(len(action_filter))

        for i, action in enumerate(action_filter):
            poses_valid, poses_valid_2d, actions_valid = fetch(
                subjects_test, dataset, keypoints, [action], stride)
            valid_loader = DataLoader(PoseGenerator(poses_valid,
                                                    poses_valid_2d,
                                                    actions_valid),
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
            errors_p1[i], errors_p2[i] = evaluate(valid_loader, model_pos,
                                                  device)

        print('Protocol #1   (MPJPE) action-wise average: {:.2f} (mm)'.format(
            np.mean(errors_p1).item()))
        print('Protocol #2 (P-MPJPE) action-wise average: {:.2f} (mm)'.format(
            np.mean(errors_p2).item()))
        exit(0)

    poses_train, poses_train_2d, actions_train = fetch(subjects_train, dataset,
                                                       keypoints,
                                                       action_filter, stride)
    train_loader = DataLoader(PoseGenerator(poses_train, poses_train_2d,
                                            actions_train),
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)

    poses_valid, poses_valid_2d, actions_valid = fetch(subjects_test, dataset,
                                                       keypoints,
                                                       action_filter, stride)
    valid_loader = DataLoader(PoseGenerator(poses_valid, poses_valid_2d,
                                            actions_valid),
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              pin_memory=True)

    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr_now))

        # Train for one epoch
        epoch_loss, lr_now, glob_step = train(train_loader,
                                              model_pos,
                                              criterion,
                                              optimizer,
                                              device,
                                              args.lr,
                                              lr_now,
                                              glob_step,
                                              args.lr_decay,
                                              args.lr_gamma,
                                              max_norm=args.max_norm)

        # Evaluate
        error_eval_p1, error_eval_p2 = evaluate(valid_loader, model_pos,
                                                device)

        # Update log file
        logger.append(
            [epoch + 1, lr_now, epoch_loss, error_eval_p1, error_eval_p2])

        # Save checkpoint
        if error_best is None or error_best > error_eval_p1:
            error_best = error_eval_p1
            save_ckpt(
                {
                    'epoch': epoch + 1,
                    'lr': lr_now,
                    'step': glob_step,
                    'state_dict': model_pos.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'error': error_eval_p1
                },
                ckpt_dir_path,
                suffix='best')

        if (epoch + 1) % args.snapshot == 0:
            save_ckpt(
                {
                    'epoch': epoch + 1,
                    'lr': lr_now,
                    'step': glob_step,
                    'state_dict': model_pos.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'error': error_eval_p1
                }, ckpt_dir_path)

    logger.close()
    logger.plot(['loss_train', 'error_eval_p1'])
    savefig(path.join(ckpt_dir_path, 'log.eps'))

    return
Example #6
def main(args):
    print('==> Using settings {}'.format(args))

    convm = torch.zeros(3, 17, 17, dtype=torch.float)

    print('==> Loading dataset...')
    dataset_path = path.join('data', 'data_3d_' + args.dataset + '.npz')
    if args.dataset == 'h36m':
        from common.h36m_dataset import Human36mDataset
        dataset = Human36mDataset(dataset_path)
    else:
        raise KeyError('Invalid dataset')

    print('==> Preparing data...')
    dataset = read_3d_data(dataset)

    print('==> Loading 2D detections...')
    keypoints = create_2d_data(
        path.join('data',
                  'data_2d_' + args.dataset + '_' + args.keypoints + '.npz'),
        dataset)

    cudnn.benchmark = True
    device = torch.device("cuda")

    # Create model
    print("==> Creating model...")

    if args.architecture == 'linear':
        from models.linear_model import LinearModel, init_weights
        num_joints = dataset.skeleton().num_joints()
        model_pos = LinearModel(num_joints * 2,
                                (num_joints - 1) * 3).to(device)
        model_pos.apply(init_weights)
    elif args.architecture == 'gcn':
        from models.sem_gcn import SemGCN
        from common.graph_utils import adj_mx_from_skeleton
        p_dropout = (None if args.dropout == 0.0 else args.dropout)
        adj = adj_mx_from_skeleton(dataset.skeleton())
        model_pos = SemGCN(convm,
                           adj,
                           args.hid_dim,
                           num_layers=args.num_layers,
                           p_dropout=p_dropout,
                           nodes_group=dataset.skeleton().joints_group()
                           if args.non_local else None).to(device)
    else:
        raise KeyError('Invalid model architecture')

    print("==> Total parameters: {:.2f}M".format(
        sum(p.numel() for p in model_pos.parameters()) / 1000000.0))

    # Resume from a checkpoint
    ckpt_path = args.evaluate

    if path.isfile(ckpt_path):
        print("==> Loading checkpoint '{}'".format(ckpt_path))
        ckpt = torch.load(ckpt_path)
        start_epoch = ckpt['epoch']
        error_best = ckpt['error']
        model_pos.load_state_dict(ckpt['state_dict'])
        print("==> Loaded checkpoint (Epoch: {} | Error: {})".format(
            start_epoch, error_best))
    else:
        raise RuntimeError("==> No checkpoint found at '{}'".format(ckpt_path))

    print('==> Rendering...')

    poses_2d = keypoints[args.viz_subject][args.viz_action]
    out_poses_2d = poses_2d[args.viz_camera]
    out_actions = [args.viz_camera] * out_poses_2d.shape[0]

    poses_3d = dataset[args.viz_subject][args.viz_action]['positions_3d']
    assert len(poses_3d) == len(poses_2d), 'Camera count mismatch'
    out_poses_3d = poses_3d[args.viz_camera]

    ground_truth = dataset[args.viz_subject][args.viz_action]['positions_3d'][
        args.viz_camera].copy()

    input_keypoints = out_poses_2d.copy()
    render_loader = DataLoader(PoseGenerator([out_poses_3d], [out_poses_2d],
                                             [out_actions]),
                               batch_size=args.batch_size,
                               shuffle=False,
                               num_workers=args.num_workers,
                               pin_memory=True)

    prediction = evaluate(render_loader, model_pos, device,
                          args.architecture)[0]

    # Invert camera transformation
    cam = dataset.cameras()[args.viz_subject][args.viz_camera]
    prediction = camera_to_world(prediction, R=cam['orientation'], t=0)
    prediction[:, :, 2] -= np.min(prediction[:, :, 2])
    ground_truth = camera_to_world(ground_truth, R=cam['orientation'], t=0)
    ground_truth[:, :, 2] -= np.min(ground_truth[:, :, 2])

    anim_output = {'Regression': prediction, 'Ground truth': ground_truth}
    input_keypoints = image_coordinates(input_keypoints[..., :2],
                                        w=cam['res_w'],
                                        h=cam['res_h'])
    render_animation(input_keypoints,
                     anim_output,
                     dataset.skeleton(),
                     dataset.fps(),
                     args.viz_bitrate,
                     cam['azimuth'],
                     args.viz_output,
                     limit=args.viz_limit,
                     downsample=args.viz_downsample,
                     size=args.viz_size,
                     input_video_path=args.viz_video,
                     viewport=(cam['res_w'], cam['res_h']),
                     input_video_skip=args.viz_skip)
Example #7
import time


# record time
def ckpt_time(ckpt=None):
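    # With no argument, return the current time; given a previous timestamp, return
    # (elapsed seconds, current time) so the caller can chain measurements.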
    if not ckpt:
        return time.time()
    else:
        return time.time() - float(ckpt), time.time()


time0 = ckpt_time()
print('Loading 3D dataset...')
dataset_path = 'data/data_3d_' + args.dataset + '.npz'  #  dataset 'h36m'
from common.h36m_dataset import Human36mDataset
dataset = Human36mDataset(dataset_path)  #'data/data_3d_h36m.npz'

ckpt, time1 = ckpt_time(time0)
print('Loading the 3D dataset took {:.2f} seconds'.format(ckpt))

# Choose the keypoint metadata format according to the detector used; we use Detectron (COCO layout)
from data.data_utils import suggest_metadata, suggest_pose_importer
metadata = suggest_metadata('detectron_pt_coco')
print('Loading 2D detection keypoints...')

if args.input_npz:
    # if a keypoint npz file already exists
    npz = np.load(args.input_npz)
    keypoints = npz['kpts']
else:
    # create keypoints with AlphaPose
Example #8
def main(args):
    print('==> Using settings {}'.format(args))

    print('==> Loading dataset...')
    if args.dataset == 'h36m':
        dataset_path = path.join('data', 'data_3d_' + args.dataset + '.npz')
        from common.h36m_dataset import Human36mDataset, TRAIN_SUBJECTS, TEST_SUBJECTS
        dataset = Human36mDataset(dataset_path)
        subjects_train = TRAIN_SUBJECTS
        subjects_test = TEST_SUBJECTS

        print('==> Preparing data ' + args.dataset + "...")
        dataset = read_3d_data(dataset)
        adj = adj_mx_from_skeleton(dataset.skeleton())
        nodes_group = dataset.skeleton().joints_group() if args.non_local else None

        print('==> Loading 2D detections...')
        keypoints = create_2d_data(path.join('data', 'data_2d_' + args.dataset + '_' + args.keypoints + '.npz'), dataset)
        action_filter = None if args.actions == '*' else args.actions.split(',')

        stride = args.downsample
        if action_filter is not None:
            action_filter = list(map(lambda x: dataset.define_actions(x)[0], action_filter))
            print('==> Selected actions: {}'.format(action_filter))

        if not args.evaluate:
            print('==> Build DataLoader...')
            poses_train, poses_train_2d, actions_train = fetch(subjects_train, dataset, keypoints, action_filter, stride)
            train_loader = DataLoader(PoseGenerator(poses_train, poses_train_2d, actions_train), batch_size=args.batch_size,
                                      shuffle=True, num_workers=args.num_workers, pin_memory=True)
            poses_valid, poses_valid_2d, actions_valid = fetch(subjects_test, dataset, keypoints, action_filter, stride)
            valid_loader = DataLoader(PoseGenerator(poses_valid, poses_valid_2d, actions_valid), batch_size=args.batch_size,
                                      shuffle=False, num_workers=args.num_workers, pin_memory=True)
    elif "synmit" in args.dataset:
        dataset_path = args.dataset_path
        from common.synmit_dataset import SynDataset17, SynDataset17_h36m
        if "h36m" in args.dataset:
            train_dataset = SynDataset17_h36m(dataset_path, image_set="train")
            valid_dataset = SynDataset17_h36m(dataset_path, image_set="val")
        else:
            train_dataset = SynDataset17(dataset_path, image_set="train")
            valid_dataset = SynDataset17(dataset_path, image_set="val")
        dataset = train_dataset
        adj = adj_mx_from_edges(dataset.jointcount(), dataset.edge(), sparse=False)
        nodes_group = dataset.nodesgroup() if args.non_local else None
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=args.num_workers, pin_memory=True)
        valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size,
                                  shuffle=False, num_workers=args.num_workers, pin_memory=True)
    else:
        raise KeyError('Invalid dataset')


    cudnn.benchmark = True
    device = torch.device("cuda")

    # Create model
    print("==> Creating model...")
    p_dropout = (None if args.dropout == 0.0 else args.dropout)
    view_count = dataset.view_count if args.multiview else None
    model_pos = MultiviewSemGCN(adj, args.hid_dim, coords_dim=(3, 1), num_layers=args.num_layers, p_dropout=p_dropout,
                                view_count=view_count, nodes_group=nodes_group).to(device)
    print("==> Total parameters: {:.2f}M".format(sum(p.numel() for p in model_pos.parameters()) / 1000000.0))

    criterion = nn.MSELoss(reduction='mean').to(device)
    optimizer = torch.optim.Adam(model_pos.parameters(), lr=args.lr)

    # Optionally resume from a checkpoint
    if args.resume or args.evaluate:
        ckpt_path = (args.resume if args.resume else args.evaluate)

        if path.isfile(ckpt_path):
            print("==> Loading checkpoint '{}'".format(ckpt_path))
            ckpt = torch.load(ckpt_path)
            start_epoch = ckpt['epoch']
            error_best = ckpt['error']
            glob_step = ckpt['step']
            lr_now = ckpt['lr']
            # lr_now = args.lr
            ####
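            # Rename legacy "nonlocal" parameter keys to the current "nonlocal_layer"
            # naming before loading the state dict.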
            if args.dataset == "h36m":
                for k in list(ckpt['state_dict'].keys()):
                    v = ckpt['state_dict'].pop(k)
                    if (type(k) == str) and ("nonlocal" in k):
                        k = k.replace("nonlocal","nonlocal_layer")
                    ckpt['state_dict'][k] = v

            model_pos.load_state_dict(ckpt['state_dict'])
            optimizer.load_state_dict(ckpt['optimizer'])
            print("==> Loaded checkpoint (Epoch: {} | Error: {})".format(start_epoch, error_best))

            if args.resume:
                ckpt_dir_path = path.dirname(ckpt_path)
                logger = Logger(path.join(ckpt_dir_path, 'log.txt'), resume=True)
        else:
            raise RuntimeError("==> No checkpoint found at '{}'".format(ckpt_path))
    else:
        start_epoch = 0
        error_best = None
        glob_step = 0
        lr_now = args.lr
        mv_str = "mv_" if args.multiview else ""
        ckpt_dir_path = path.join(args.checkpoint, args.dataset + "_" + mv_str + datetime.datetime.now().isoformat(timespec="seconds"))
        ckpt_dir_path = ckpt_dir_path.replace(":","-")

        if not path.exists(ckpt_dir_path):
            os.makedirs(ckpt_dir_path)
            print('==> Making checkpoint dir: {}'.format(ckpt_dir_path))

        logger = Logger(os.path.join(ckpt_dir_path, 'log.txt'))
        logger.set_names(['epoch', 'lr', 'loss_train', 'error_eval_p1', 'error_eval_p2'])

    if args.evaluate:
        print('==> Evaluating...')
        if args.dataset == 'h36m':
            if action_filter is None:
                action_filter = dataset.define_actions()

            errors_p1 = np.zeros(len(action_filter))
            errors_p2 = np.zeros(len(action_filter))

            for i, action in enumerate(action_filter):
                poses_valid, poses_valid_2d, actions_valid = fetch(subjects_test, dataset, keypoints, [action], stride)
                valid_loader = DataLoader(PoseGenerator(poses_valid, poses_valid_2d, actions_valid),
                                          batch_size=args.batch_size, shuffle=False,
                                          num_workers=args.num_workers, pin_memory=True)
                errors_p1[i], errors_p2[i] = evaluate(valid_loader, model_pos, device)
        elif "synmit" in args.dataset:
            if "h36m" in args.dataset:
                test_dataset = SynDataset17_h36m(dataset_path, image_set="test")
            else:
                test_dataset = SynDataset17(dataset_path, image_set="test")
            test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                                      shuffle=False, num_workers=args.num_workers, pin_memory=True)
            errors_p1, errors_p2 = evaluate(test_loader, model_pos, device)

        print('Protocol #1 (MPJPE)     action-wise average: {:.2f} (mm)'.format(np.mean(errors_p1).item()))
        print('Protocol #2 (REL-MPJPE) action-wise average: {:.2f} (mm)'.format(np.mean(errors_p2).item()))
        exit(0)

    epoch_loss = 1e5
    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr_now))

        # Train for one epoch
        epoch_loss, lr_now, glob_step = train(epoch_loss, train_loader, model_pos, criterion, optimizer, device, args.lr, lr_now,
                                              glob_step, args.lr_decay, args.lr_gamma, max_norm=args.max_norm, loss_3d=args.loss3d,
                                              earlyend=args.earlyend)

        # Evaluate
        error_eval_p1, error_eval_p2 = evaluate(valid_loader, model_pos, device)

        # Update log file
        logger.append([epoch + 1, lr_now, epoch_loss, error_eval_p1, error_eval_p2])

        # Save checkpoint
        if error_best is None or error_best > error_eval_p1:
            error_best = error_eval_p1
            save_ckpt({'epoch': epoch + 1, 'lr': lr_now, 'step': glob_step, 'state_dict': model_pos.state_dict(),
                       'optimizer': optimizer.state_dict(), 'error': error_eval_p1}, ckpt_dir_path, suffix='best')

        if (epoch + 1) % args.snapshot == 0:
            save_ckpt({'epoch': epoch + 1, 'lr': lr_now, 'step': glob_step, 'state_dict': model_pos.state_dict(),
                       'optimizer': optimizer.state_dict(), 'error': error_eval_p1}, ckpt_dir_path)

    logger.close()
    logger.plot(['loss_train', 'error_eval_p1'])
    savefig(path.join(ckpt_dir_path, 'log.eps'))

    return
Example #9
def load_data(args):
    print("Loading dataset...")
    dataset_path = "data/data_3d_" + args.dataset + ".npz"
    if args.dataset == "h36m":
        from common.h36m_dataset import Human36mDataset
        dataset = Human36mDataset(dataset_path, args.keypoints)
    elif args.dataset.startswith('humaneva'):
        from common.humaneva_dataset import HumanEvaDataset
        dataset = HumanEvaDataset(dataset_path)
    else:
        raise KeyError("Invalid dataset")

    print("Preparing data...")
    for subject in dataset.subjects():
        for action in dataset[subject].keys():
            anim = dataset[subject][action]

            if "positions" in anim:
                positions_3d = []
                for cam in anim["cameras"]:
                    pos_3d = world_to_camera(anim["positions"], R=cam["orientation"], t=cam["translation"])
                    pos_3d[:, 1:] -= pos_3d[:, :1]  # Remove global offset, but keep trajectory in first position
                    positions_3d.append(pos_3d)
                anim["positions_3d"] = positions_3d

    print("Loading 2D detections...")
    keypoints = np.load("data/data_2d_" + args.dataset + "_" + args.keypoints + ".npz", allow_pickle=True)
    keypoints_metadata = keypoints["metadata"].item()
    keypoints_metadata.update({'layout_name': 'h36m'})
    keypoints_symmetry = keypoints_metadata["keypoints_symmetry"]

    if args.dataset.startswith('humaneva'):
        kps_left, kps_right = [2, 3, 4, 8, 9, 10], [5, 6, 7, 11, 12, 13]
    else:
        kps_left, kps_right = list(keypoints_symmetry[0]), list(keypoints_symmetry[1])

    joints_left, joints_right = list(dataset.skeleton().joints_left()), list(dataset.skeleton().joints_right())
    keypoints = keypoints["positions_2d"].item()

    for subject in dataset.subjects():
        assert subject in keypoints, 'Subject {} is missing from the 2D detections dataset'.format(subject)
        for action in dataset[subject].keys():
            assert action in keypoints[
                subject], 'Action {} of subject {} is missing from the 2D detections dataset'.format(action, subject)
            if "positions_3d" not in dataset[subject][action]:
                continue

            for cam_idx in range(len(keypoints[subject][action])):

                # We check for >= instead of == because some videos in H3.6M contain extra frames
                mocap_length = dataset[subject][action]["positions_3d"][cam_idx].shape[0]
                assert keypoints[subject][action][cam_idx].shape[0] >= mocap_length

                if keypoints[subject][action][cam_idx].shape[0] > mocap_length:
                    keypoints[subject][action][cam_idx] = keypoints[subject][action][cam_idx][:mocap_length]

            assert len(keypoints[subject][action]) == len(dataset[subject][action]["positions_3d"])

    for subject in keypoints.keys():
        for action in keypoints[subject]:
            for cam_idx, kps in enumerate(keypoints[subject][action]):
                # Normalize camera frame
                cam = dataset.cameras()[subject][cam_idx]

                # HumanEva keypoints were detected with Mask R-CNN (17 COCO keypoints)
                # https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/keypoints.py
                # Convert the MSCOCO keypoint layout to the Human3.6M layout
                if args.dataset.startswith('humaneva'):
                    kps_15 = np.zeros((kps.shape[0], 15, kps.shape[2]), dtype=np.float32)
                    kps_15[:, 0] = (kps[:, 11] + kps[:, 12]) / 2
                    kps_15[:, 1] = (kps[:, 5] + kps[:, 6]) / 2
                    kps_15[:, 2] = kps[:, 5]
                    kps_15[:, 3] = kps[:, 7]
                    kps_15[:, 4] = kps[:, 9]
                    kps_15[:, 5] = kps[:, 6]
                    kps_15[:, 6] = kps[:, 8]
                    kps_15[:, 7] = kps[:, 10]
                    kps_15[:, 8] = kps[:, 11]
                    kps_15[:, 9] = kps[:, 13]
                    kps_15[:, 10] = kps[:, 15]
                    kps_15[:, 11] = kps[:, 12]
                    kps_15[:, 12] = kps[:, 14]
                    kps_15[:, 13] = kps[:, 16]
                    kps_15[:, 14] = kps[:, 0]

                    kps_15[..., :2] = normalize_screen_coordinates(kps_15[..., :2], w=cam["res_w"], h=cam["res_h"])
                    keypoints[subject][action][cam_idx] = kps_15

                else:
                    kps[..., :2] = normalize_screen_coordinates(kps[..., :2], w=cam["res_w"], h=cam["res_h"])
                    keypoints[subject][action][cam_idx] = kps

    return keypoints, dataset, keypoints_metadata, kps_left, kps_right, joints_left, joints_right
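
# For reference, an illustrative call (assuming the same argparse-style `args` with
# `dataset` and `keypoints` fields used throughout these snippets):
#
#     keypoints, dataset, keypoints_metadata, kps_left, kps_right, joints_left, joints_right = load_data(args)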
Example #10
def main():
    dataset_path = "./data/data_3d_h36m.npz"    # 加载数据
    from common.h36m_dataset import Human36mDataset
    dataset = Human36mDataset(dataset_path)
    dataset = read_3d_data(dataset)
    cudnn.benchmark = True
    device = torch.device("cpu")
    from models.sem_gcn import SemGCN
    from common.graph_utils import adj_mx_from_skeleton
    p_dropout = None
    adj = adj_mx_from_skeleton(dataset.skeleton())
    model_pos = SemGCN(adj, 128, num_layers=4, p_dropout=p_dropout,
                       nodes_group=dataset.skeleton().joints_group()).to(device)
    ckpt_path = "./checkpoint/pretrained/ckpt_semgcn_nonlocal_sh.pth.tar"
    ckpt = torch.load(ckpt_path, map_location='cpu')
    model_pos.load_state_dict(ckpt['state_dict'], False)
    model_pos.eval()
    # ============ Added code ==============
    # 2D keypoints for one person, produced by the project's 2D data-processing code
    inputs_2d = [[483.0, 450], [503, 450], [503, 539], [496, 622], [469, 450], [462, 546], [469, 622], [483, 347],
                 [483, 326], [489, 264], [448, 347], [448, 408], [441, 463], [517, 347], [524, 408], [538, 463]]

    # # Detection result for the OpenPose test sample
    # inputs_2d = [[86.0, 137], [99, 128], [94, 127], [97, 110], [89, 105], [102, 129], [116, 116], [99, 110],
    #              [105, 93], [117, 69], [147, 63], [104, 93], [89, 69], [82, 38], [89, 139], [94, 140]]

    inputs_2d = np.array(inputs_2d)
    # inputs_2d[:, 1] = np.max(inputs_2d[:, 1]) - inputs_2d[:, 1]   # Flip to an upright pose; the original data is upside down

    cam = dataset.cameras()['S1'][0]    # Get the camera parameters
    inputs_2d[..., :2] = normalize_screen_coordinates(inputs_2d[..., :2], w=cam['res_w'], h=cam['res_h'])  # Normalize the 2D coordinates

    # Plot the normalized 2D keypoints and label each with its index
    print(inputs_2d)    # Print the normalized 2D keypoint coordinates
    d_x = inputs_2d[:, 0]
    d_y = inputs_2d[:, 1]
    plt.figure()
    plt.scatter(d_x, d_y)
    for i, txt in enumerate(np.arange(inputs_2d.shape[0])):
        plt.annotate(txt, (d_x[i], d_y[i]))     # Label each point with its index
    # plt.show()      # Show the normalized 2D keypoint plot

    # Get the 3D prediction
    inputs_2d = torch.tensor(inputs_2d, dtype=torch.float32)    # Convert to a tensor
    outputs_3d = model_pos(inputs_2d).cpu()         # Run the model
    outputs_3d[:, :, :] -= outputs_3d[:, :1, :]     # Remove global offset
    predictions = [outputs_3d.detach().numpy()]     # Prediction results
    prediction = np.concatenate(predictions)[0]     # Concatenate and take the first sequence
    # Invert the camera transformation
    prediction = camera_to_world(prediction, R=cam['orientation'], t=0)     # The exact R and t matter little; the choice depends on which camera parameters are used, and some subjects have no t
    prediction[:, 2] -= np.min(prediction[:, 2])    # Shift up by min(prediction[:, 2]) so all coordinates are non-negative
    print('prediction')
    print(prediction)   # Print the 3D coordinates used for plotting
    plt.figure()
    ax = plt.subplot(111, projection='3d')  # Create a 3D plot
    o_x = prediction[:, 0]
    o_y = prediction[:, 1]
    o_z = prediction[:, 2]
    print(o_x)
    print(o_y)
    print(o_z)
    ax.scatter(o_x, o_y, o_z)

    temp = o_x
    x = [temp[9], temp[8], temp[7], temp[10], temp[11], temp[12]]
    temp = o_y
    y = [temp[9], temp[8], temp[7], temp[10], temp[11], temp[12]]
    temp = o_z
    z = [temp[9], temp[8], temp[7], temp[10], temp[11], temp[12]]
    ax.plot(x, y, z)

    temp = o_x
    x = [temp[7], temp[0], temp[4], temp[5], temp[6]]
    temp = o_y
    y = [temp[7], temp[0], temp[4], temp[5], temp[6]]
    temp = o_z
    z = [temp[7], temp[0], temp[4], temp[5], temp[6]]
    ax.plot(x, y, z)

    temp = o_x
    x = [temp[0], temp[1], temp[2], temp[3]]
    temp = o_y
    y = [temp[0], temp[1], temp[2], temp[3]]
    temp = o_z
    z = [temp[0], temp[1], temp[2], temp[3]]
    ax.plot(x, y, z)

    temp = o_x
    x = [temp[7], temp[13], temp[14], temp[15]]
    temp = o_y
    y = [temp[7], temp[13], temp[14], temp[15]]
    temp = o_z
    z = [temp[7], temp[13], temp[14], temp[15]]
    ax.plot(x, y, z)

    # temp = o_x
    # x = [temp[0], temp[14]]
    # temp = o_y
    # y = [temp[0], temp[14]]
    # temp = o_z
    # z = [temp[0], temp[14]]
    # ax.plot(y, x, z)
    #
    # temp = o_x
    # x = [temp[0], temp[15]]
    # temp = o_y
    # y = [temp[0], temp[15]]
    # temp = o_z
    # z = [temp[0], temp[15]]
    # ax.plot(y, x, z)

    # Adjust the axis scaling so that the z-axis spans twice the range of the other axes
    from matplotlib.pyplot import MultipleLocator
    major_locator = MultipleLocator(0.5)
    ax.xaxis.set_major_locator(major_locator)
    ax.yaxis.set_major_locator(major_locator)
    ax.zaxis.set_major_locator(major_locator)
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([0.5, 0.5, 1, 1]))

    plt.show()
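

# The four hand-written limb chains above could equally be drawn with a small helper
# that reuses the same joint indices (an illustrative refactor, not part of the original):
def plot_skeleton_chains(ax, pts):
    # Joint-index chains copied verbatim from the plotting blocks above.
    chains = [
        [9, 8, 7, 10, 11, 12],
        [7, 0, 4, 5, 6],
        [0, 1, 2, 3],
        [7, 13, 14, 15],
    ]
    for chain in chains:
        # Fancy indexing keeps the same x/y/z ordering used by ax.plot above.
        ax.plot(pts[chain, 0], pts[chain, 1], pts[chain, 2])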
Example #11
                positions /= 1000  # Meters instead of millimeters
                output[subject][canonical_name] = positions.astype('float32')

        print('Saving...')
        np.savez_compressed(output_filename, positions_3d=output)

        print('Done.')

    else:
        print('Please specify the dataset source')
        exit(0)

    # Create 2D pose file
    print("")
    print("Computing ground-truth 2D poses...")
    dataset = Human36mDataset(output_filename + ".npz")
    output_2d_poses = {}
    for subject in dataset.subjects():
        output_2d_poses[subject] = {}
        for action in dataset[subject].keys():
            anim = dataset[subject][action]

            positions_2d = []
            for cam in anim["cameras"]:
                pos_3d = world_to_camera(anim["positions"],
                                         R=cam["orientation"],
                                         t=cam["translation"])
                pos_2d = wrap(project_to_2d,
                              pos_3d,
                              cam["intrinsic"],
                              unsqueeze=True)
Example #12
def data_preparation(args):
    """
    load the h36m dataset
    generate data loader for training posenet, poseaug, and cross-data evaluation
    """
    dataset_path = path.join('data', 'data_3d_' + args.dataset + '.npz')
    if args.dataset == 'h36m':
        from common.h36m_dataset import Human36mDataset, TEST_SUBJECTS
        dataset = Human36mDataset(dataset_path)
        if args.s1only:
            subjects_train = ['S1']
        else:
            subjects_train = ['S1', 'S5', 'S6', 'S7', 'S8']
        subjects_test = TEST_SUBJECTS
    else:
        raise KeyError('Invalid dataset')

    print('==> Loading 3D data...')
    dataset = read_3d_data(dataset)

    print('==> Loading 2D detections...')
    keypoints = create_2d_data(
        path.join('data',
                  'data_2d_' + args.dataset + '_' + args.keypoints + '.npz'),
        dataset)

    action_filter = None if args.actions == '*' else args.actions.split(',')
    if action_filter is not None:
        action_filter = list(map(lambda x: dataset.define_actions(x)[0],
                                 action_filter))
        print('==> Selected actions: {}'.format(action_filter))

    stride = args.downsample

    ############################################
    # general 2D-3D pair dataset
    ############################################
    poses_train, poses_train_2d, actions_train, cams_train = fetch(
        subjects_train, dataset, keypoints, action_filter, stride)
    poses_valid, poses_valid_2d, actions_valid, cams_valid = fetch(
        subjects_test, dataset, keypoints, action_filter, stride)
    # prepare train loader for detected 2D.
    train_det2d3d_loader = DataLoader(PoseDataSet(poses_train, poses_train_2d,
                                                  actions_train, cams_train),
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

    # prepare train loader for GT 2D-3D, which will be updated via projection.
    train_gt2d3d_loader = DataLoader(PoseDataSet(poses_train, poses_train_2d,
                                                 actions_train, cams_train),
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers,
                                     pin_memory=True)

    valid_loader = DataLoader(PoseDataSet(poses_valid, poses_valid_2d,
                                          actions_valid, cams_valid),
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              pin_memory=True)

    ############################################
    # data loader for GAN training
    ############################################
    target_2d_loader = DataLoader(PoseTarget(poses_train_2d),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    target_3d_loader = DataLoader(PoseTarget(poses_train),
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True)

    ############################################
    # prepare cross dataset validation
    ############################################
    # 3DHP -  2929 version
    mpi3d_npz = np.load(
        'data_extra/test_set/test_3dhp.npz')  # this is the 2929 version
    tmp = mpi3d_npz
    mpi3d_loader = DataLoader(PoseBuffer([tmp['pose3d']], [tmp['pose2d']]),
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              pin_memory=True)

    return {
        'dataset': dataset,
        'train_det2d3d_loader': train_det2d3d_loader,
        'train_gt2d3d_loader': train_gt2d3d_loader,
        'target_2d_loader': target_2d_loader,
        'target_3d_loader': target_3d_loader,
        'H36M_test': valid_loader,
        'mpi3d_loader': mpi3d_loader,
        'action_filter': action_filter,
        'subjects_test': subjects_test,
        'keypoints': keypoints,
    }
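
# A hedged usage sketch (illustrative only; the Namespace lists exactly the argparse
# fields data_preparation reads, and the values are assumptions):
if __name__ == '__main__':
    from argparse import Namespace
    demo_args = Namespace(dataset='h36m', s1only=False, keypoints='gt', actions='*',
                          downsample=1, batch_size=1024, num_workers=4)
    loaders = data_preparation(demo_args)
    print('Available loaders:', [k for k in loaders if k.endswith('_loader')])
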
def add_audio_given_start_inference_vis(args, val_loader, model, device, 
          epoch, model_writer, global_step):
    print('Loading dataset...')
    dataset_path = args.h36m_data

    from common.h36m_dataset import Human36mDataset
    dataset = Human36mDataset(dataset_path)

    cam = dataset.cameras()['S1'][0]
    # cam = dataset.cameras()[args.viz_subject][args.viz_camera]
    rot = dataset.cameras()['S1'][0]['orientation']

    fps = args.fps

    # Switch to eval mode
    model.eval()

    if args.save_generation:
        gen_times = args.gen_num_for_eval
    else:
        gen_times = 20

    nums_per_mp3 = args.num_per_mp3
    mp3_folder = "/data/jiaman/github/cvpr20_dance/json_data/block_20s_train_val_data_discrete300/block_mp3_files"
    with torch.no_grad():
        for i, (mask, pos_vec, pose3d_discrete_seq, pose3d_discrete_gt_seq, \
            mfcc_data, beat_data, mp3_name) in enumerate(val_loader):
            # BS X T X 48, BS X T X 48, BS X 1 X T, BS X 1 X T, BS X T X 48
            bs = pose3d_discrete_seq.size()[0]
            timesteps = pose3d_discrete_seq.size()[1]

            # Send to device
            pose3d_discrete_seq = pose3d_discrete_seq.to(device)
            pose3d_discrete_gt_seq = pose3d_discrete_gt_seq.to(device)

            # Vis GT
            if not args.zero_start:
                first_step = pose3d_discrete_seq.squeeze(0)[0, :] # 48
                first_step = first_step.unsqueeze(0).view(1, -1, 3).data.cpu().float() # 1 X 16 X 3
                gt_pose3d = pose3d_discrete_gt_seq.squeeze(0).view(timesteps, -1, 3).data.cpu().float() # T X 16 X 3
                mask = mask.squeeze(0).squeeze(0).unsqueeze(1).unsqueeze(2).float()
                act_len = mask.sum() + 1
                gt_pose3d = torch.cat((first_step, gt_pose3d), dim=0) # (T+1) X 16 X 3
            else:
                gt_pose3d = pose3d_discrete_gt_seq.squeeze(0).view(timesteps, -1, 3).data.cpu().float() # T X 16 X 3
                mask = mask.squeeze(0).squeeze(0).unsqueeze(1).unsqueeze(2).float()
                act_len = mask.sum()
            
            act_len = int(act_len)
            gt_pose3d = gt_pose3d[:act_len, :, :] # act_T X 16 X 3

            root_zeros = torch.zeros(act_len, 1, 3).float()
            root_zeros.fill_(144) # Depends on how many classes for classification
            gt_pose3d = torch.cat((root_zeros, gt_pose3d), dim=1) # T X 17 X 3
            gt_pose3d = gt_pose3d.data.cpu()
            gt_pose3d = np.array(gt_pose3d, dtype="float32")

            gt_pose3d = convert_discrete_to_coord(gt_pose3d) # T X 17 X 3
            gt_pose3d = gt_pose3d - np.expand_dims(gt_pose3d[:, 0, :], axis=1) # Make root to zeros
            save_gt_pose3d = np.array(gt_pose3d, dtype="float32")
            
            gt_pose3d = camera_to_world(save_gt_pose3d, R=rot, t=0)
            # We don't have the trajectory, but at least we can rebase the height
            gt_pose3d[:, :, 2] -= np.min(gt_pose3d[:, :, 2], axis=1, keepdims=True)

            # Vis generation
            init_start_pose = pose3d_discrete_seq[:, :args.start_steps, :] # BS X T' X 48

            # Add audio info
            for gen_idx in range(nums_per_mp3):
                if args.add_mfcc and args.add_beat:
                    mfcc_data_input = mfcc_data.to(device) # bs(1) X T X 26
                    beat_data_input = beat_data.to(device).long() # bs(1) X T
                    cal_start_time = time.time()
                    gen_pose3d = model.given_start_inference(init_start_pose, act_len, device, \
                    mfcc_feats=mfcc_data_input, beat_feats=beat_data_input) # 1 X T X 48
                    cal_end_time = time.time()
                    print("Total time for whole seq:{0}".format(cal_end_time-cal_start_time))
                    print("Mean time for each frame among whole seq:{0}".format((cal_end_time-cal_start_time)/act_len))
                elif args.add_mfcc:
                    mfcc_data_input = mfcc_data.to(device)
                    gen_pose3d = model.given_start_inference(init_start_pose, act_len, device, mfcc_feats=mfcc_data_input) # 1 X T X 48
                elif args.add_beat:
                    beat_data_input = beat_data.to(device).long()
                    gen_pose3d = model.given_start_inference(init_start_pose, act_len, device, beat_feats=beat_data_input) # 1 X T X 48
                else:
                    gen_pose3d = model.given_start_inference(init_start_pose, act_len, device) # 1 X T X 48

                gen_pose3d = gen_pose3d.squeeze(0) # T X 48
                gen_pose3d = gen_pose3d.view(act_len, -1, 3) # T X 16 X 3
                
                root_zeros = torch.zeros(act_len, 1, 3).float()
                root_zeros.fill_(144) # Depends on how many classes for classification
                gen_pose3d = torch.cat((root_zeros, gen_pose3d), dim=1) # T X 17 X 3
                gen_pose3d = gen_pose3d.data.cpu()
                gen_pose3d = np.array(gen_pose3d, dtype=np.float32)

                gen_pose3d = convert_discrete_to_coord(gen_pose3d) # T X 17 X 3
                gen_pose3d = gen_pose3d - np.expand_dims(gen_pose3d[:, 0, :], axis=1) # Make root to zeros
                save_gen_pose3d = np.array(gen_pose3d, dtype="float32")

                if args.save_generation:
                    if not os.path.exists(args.gen_res_folder):
                        os.makedirs(args.gen_res_folder)
                    dest_gen_npy_path = os.path.join(args.gen_res_folder, str(i)+"_"+str(gen_idx)+"_gen.npy")  # Note: the beat is delayed by one step
                    np.save(dest_gen_npy_path, save_gen_pose3d) # T X 17 X 3
                    dest_gt_npy_path = os.path.join(args.gen_res_folder, str(i)+"_gt.npy")
                    np.save(dest_gt_npy_path, save_gt_pose3d) # T X 17 X 3
                    if args.add_beat or args.add_mfcc:
                        dest_beat_path = os.path.join(args.gen_res_folder, str(i)+"_beat_mp3.json")
                        data_dict = {}
                        data_dict['beat'] = np.array(beat_data.data.cpu())[0, :act_len-1].tolist() # T - 1
                        cropped_mp3_path = os.path.join(mp3_folder, str(mp3_name[0])+".mp3")
                        data_dict['mp3'] = cropped_mp3_path
                        json.dump(data_dict, open(dest_beat_path, 'w'))
                else:
                    gen_pose3d = camera_to_world(save_gen_pose3d, R=rot, t=0)
                    # We don't have the trajectory, but at least we can rebase the height
                    gen_pose3d[:, :, 2] -= np.min(gen_pose3d[:, :, 2], axis=1, keepdims=True)

                    anim_output = {'GT Pose3D': gt_pose3d,
                    'Generated Pose3D': gen_pose3d}
                    from common.visualization import render_animation
                    input_keypoints = gen_pose3d[:, :, :2] # Just for keeping vis codes unchanged
                    if not os.path.exists(args.viz_folder):
                        os.makedirs(args.viz_folder)

                    render_animation(input_keypoints, anim_output,
                                     dataset.skeleton(), 
                                     fps,
                                     3000, cam['azimuth'], os.path.join(args.viz_folder, str(i)+"_"+str(gen_idx)+".mp4"),
                                     limit=-1, downsample=1, size=5,
                                     input_video_path=None, viewport=(cam['res_w'], cam['res_h']),
                                     input_video_skip=0)


                    # Merge visual mp4 with mp3 file
                    mp4_path = os.path.join(args.viz_folder, str(i)+"_"+str(gen_idx)+".mp4")

                    dest_folder = os.path.join(args.viz_folder, "merged_mp4")
                    if not os.path.exists(dest_folder):
                        os.makedirs(dest_folder)
                    merged_mp4_file_path = os.path.join(dest_folder, str(i)+"_"+str(gen_idx)+".mp4")

                    cropped_mp3_path = os.path.join(mp3_folder, str(mp3_name[0])+".mp3")
                    merge_audio_cmd = "ffmpeg -i " + mp4_path +" -i "+ cropped_mp3_path + " -c:v copy -c:a aac -strict experimental " + merged_mp4_file_path

                    subprocess.call(merge_audio_cmd, shell=True)

                    os.remove(mp4_path)
            
            if i >= gen_times:
                break
Example #14
def the_main_kaboose(args):
    print(args)

    try:
        # Create checkpoint directory if it does not exist
        os.makedirs(args.checkpoint)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise RuntimeError('Unable to create checkpoint directory:',
                               args.checkpoint)

    print('Loading dataset...')
    dataset_path = 'data/data_3d_' + args.dataset + '.npz'
    if args.dataset == 'h36m':
        from common.h36m_dataset import Human36mDataset
        dataset = Human36mDataset(dataset_path)
    elif args.dataset.startswith('humaneva'):
        from common.humaneva_dataset import HumanEvaDataset
        dataset = HumanEvaDataset(dataset_path)
    elif args.dataset.startswith('custom'):
        from common.custom_dataset import CustomDataset
        dataset = CustomDataset('data/data_2d_' + args.dataset + '_' +
                                args.keypoints + '.npz')
    else:
        raise KeyError('Invalid dataset')

    print('Preparing data...')
    for subject in dataset.subjects():
        for action in dataset[subject].keys():
            anim = dataset[subject][action]

            # This only applies during training.
            if 'positions' in anim:
                positions_3d = []
                for cam in anim['cameras']:
                    pos_3d = world_to_camera(anim['positions'],
                                             R=cam['orientation'],
                                             t=cam['translation'])
                    pos_3d[:,
                           1:] -= pos_3d[:, :
                                         1]  # Remove global offset, but keep trajectory in first position
                    positions_3d.append(pos_3d)
                anim['positions_3d'] = positions_3d

    print('Loading 2D detections...')
    keypoints = np.load('data/data_2d_' + args.dataset + '_' + args.keypoints +
                        '.npz',
                        allow_pickle=True)
    keypoints_metadata = keypoints['metadata'].item()
    keypoints_symmetry = keypoints_metadata['keypoints_symmetry']
    kps_left, kps_right = list(keypoints_symmetry[0]), list(
        keypoints_symmetry[1])
    joints_left, joints_right = list(dataset.skeleton().joints_left()), list(
        dataset.skeleton().joints_right())
    keypoints = keypoints['positions_2d'].item()

    # This block is only needed for training and can be skipped for inference.
    for subject in dataset.subjects():
        assert subject in keypoints, 'Subject {} is missing from the 2D detections dataset'.format(
            subject)
        for action in dataset[subject].keys():
            assert action in keypoints[
                subject], 'Action {} of subject {} is missing from the 2D detections dataset'.format(
                    action, subject)
            if 'positions_3d' not in dataset[subject][action]:
                continue

            for cam_idx in range(len(keypoints[subject][action])):

                # We check for >= instead of == because some videos in H3.6M contain extra frames
                mocap_length = dataset[subject][action]['positions_3d'][
                    cam_idx].shape[0]
                assert keypoints[subject][action][cam_idx].shape[
                    0] >= mocap_length

                if keypoints[subject][action][cam_idx].shape[0] > mocap_length:
                    # Shorten sequence
                    keypoints[subject][action][cam_idx] = keypoints[subject][
                        action][cam_idx][:mocap_length]

            assert len(keypoints[subject][action]) == len(
                dataset[subject][action]['positions_3d'])

    # Normalize keypoints to the camera frame
    for subject in keypoints.keys():
        for action in keypoints[subject]:
            for cam_idx, kps in enumerate(keypoints[subject][action]):
                # Normalize camera frame
                cam = dataset.cameras()[subject][cam_idx]
                kps[..., :2] = normalize_screen_coordinates(kps[..., :2],
                                                            w=cam['res_w'],
                                                            h=cam['res_h'])
                keypoints[subject][action][cam_idx] = kps

    subjects_train = args.subjects_train.split(',')
    subjects_semi = [] if not args.subjects_unlabeled else args.subjects_unlabeled.split(
        ',')
    if not args.render:
        subjects_test = args.subjects_test.split(',')
    else:
        subjects_test = [args.viz_subject]

    semi_supervised = len(subjects_semi) > 0
    if semi_supervised and not dataset.supports_semi_supervised():
        raise RuntimeError(
            'Semi-supervised training is not implemented for this dataset')

    def fetch(subjects, action_filter=None, subset=1, parse_3d_poses=True):
        out_poses_3d = []
        out_poses_2d = []
        out_camera_params = []
        for subject in subjects:
            print("gonna check actions for subject " + subject)

        for subject in subjects:
            for action in keypoints[subject].keys():
                if action_filter is not None:
                    found = False
                    for a in action_filter:
                        if action.startswith(a):
                            found = True
                            break
                    if not found:
                        continue

                poses_2d = keypoints[subject][action]
                for i in range(len(poses_2d)):  # Iterate across cameras
                    out_poses_2d.append(poses_2d[i])

                if subject in dataset.cameras():
                    cams = dataset.cameras()[subject]
                    assert len(cams) == len(poses_2d), 'Camera count mismatch'
                    for cam in cams:
                        if 'intrinsic' in cam:
                            out_camera_params.append(cam['intrinsic'])

                if parse_3d_poses and 'positions_3d' in dataset[subject][
                        action]:
                    poses_3d = dataset[subject][action]['positions_3d']
                    assert len(poses_3d) == len(
                        poses_2d), 'Camera count mismatch'
                    for i in range(len(poses_3d)):  # Iterate across cameras
                        out_poses_3d.append(poses_3d[i])

        if len(out_camera_params) == 0:
            out_camera_params = None
        if len(out_poses_3d) == 0:
            out_poses_3d = None

        stride = args.downsample
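        # If only a fraction of the data is requested, keep one deterministic contiguous
        # window per sequence; otherwise simply downsample every sequence by the stride.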
        if subset < 1:
            for i in range(len(out_poses_2d)):
                n_frames = int(
                    round(len(out_poses_2d[i]) // stride * subset) * stride)
                start = deterministic_random(
                    0,
                    len(out_poses_2d[i]) - n_frames + 1,
                    str(len(out_poses_2d[i])))
                out_poses_2d[i] = out_poses_2d[i][start:start +
                                                  n_frames:stride]
                if out_poses_3d is not None:
                    out_poses_3d[i] = out_poses_3d[i][start:start +
                                                      n_frames:stride]
        elif stride > 1:
            # Downsample as requested
            for i in range(len(out_poses_2d)):
                out_poses_2d[i] = out_poses_2d[i][::stride]
                if out_poses_3d is not None:
                    out_poses_3d[i] = out_poses_3d[i][::stride]

        return out_camera_params, out_poses_3d, out_poses_2d

    action_filter = None if args.actions == '*' else args.actions.split(',')
    if action_filter is not None:
        print('Selected actions:', action_filter)

    # For inference on a custom 2D-only dataset this returns (None, None, poses_valid_2d):
    # no camera intrinsics and no 3D ground truth, only the 2D keypoints.
    cameras_valid, poses_valid, poses_valid_2d = fetch(subjects_test,
                                                       action_filter)

    filter_widths = [int(x) for x in args.architecture.split(',')]
    if not args.disable_optimizations and not args.dense and args.stride == 1:
        # Use optimized model for single-frame predictions
        shape_2 = poses_valid_2d[0].shape[-2]
        shape_1 = poses_valid_2d[0].shape[-1]
        numJoints = dataset.skeleton().num_joints()
        model_pos_train = TemporalModelOptimized1f(shape_2,
                                                   shape_1,
                                                   numJoints,
                                                   filter_widths=filter_widths,
                                                   causal=args.causal,
                                                   dropout=args.dropout,
                                                   channels=args.channels)
    else:
        # Fall back to the generic model when the optimized variant is not applicable
        # (stride > 1, dense convolutions, or optimizations disabled)
        model_pos_train = TemporalModel(poses_valid_2d[0].shape[-2],
                                        poses_valid_2d[0].shape[-1],
                                        dataset.skeleton().num_joints(),
                                        filter_widths=filter_widths,
                                        causal=args.causal,
                                        dropout=args.dropout,
                                        channels=args.channels,
                                        dense=args.dense)

    model_pos = TemporalModel(poses_valid_2d[0].shape[-2],
                              poses_valid_2d[0].shape[-1],
                              dataset.skeleton().num_joints(),
                              filter_widths=filter_widths,
                              causal=args.causal,
                              dropout=args.dropout,
                              channels=args.channels,
                              dense=args.dense)

    receptive_field = model_pos.receptive_field()
    print('INFO: Receptive field: {} frames'.format(receptive_field))
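    # The receptive field is the product of the filter widths, e.g. --architecture 3,3,3,3,3
    # gives 3**5 = 243 frames, so pad = 121 frames on each side. With --causal the padding
    # is applied entirely to the left (causal_shift = pad), so each prediction depends only
    # on past frames.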
    pad = (receptive_field - 1) // 2  # Padding on each side
    if args.causal:
        print('INFO: Using causal convolutions')
        causal_shift = pad
    else:
        causal_shift = 0

    model_params = 0
    for parameter in model_pos.parameters():
        model_params += parameter.numel()
    print('INFO: Trainable parameter count:', model_params)

    if torch.cuda.is_available():
        model_pos = model_pos.cuda()
        model_pos_train = model_pos_train.cuda()

    # Default to None so the rendering branch below can safely test for a trajectory model
    # even when no checkpoint is loaded.
    model_traj = None
    if args.resume or args.evaluate:
        chk_filename = os.path.join(
            args.checkpoint, args.resume if args.resume else args.evaluate)
        print('Loading checkpoint', chk_filename)
        checkpoint = torch.load(chk_filename,
                                map_location=lambda storage, loc: storage)
        print('This model was trained for {} epochs'.format(
            checkpoint['epoch']))
        model_pos_train.load_state_dict(checkpoint['model_pos'])
        model_pos.load_state_dict(checkpoint['model_pos'])

        if args.evaluate and 'model_traj' in checkpoint:
            # Load the trajectory model if it is contained in the checkpoint (e.g. for in-the-wild inference)
            model_traj = TemporalModel(poses_valid_2d[0].shape[-2],
                                       poses_valid_2d[0].shape[-1],
                                       1,
                                       filter_widths=filter_widths,
                                       causal=args.causal,
                                       dropout=args.dropout,
                                       channels=args.channels,
                                       dense=args.dense)
            if torch.cuda.is_available():
                model_traj = model_traj.cuda()
            model_traj.load_state_dict(checkpoint['model_traj'])
        else:
            model_traj = None

    test_generator = UnchunkedGenerator(cameras_valid,
                                        poses_valid,
                                        poses_valid_2d,
                                        pad=pad,
                                        causal_shift=causal_shift,
                                        augment=False,
                                        kps_left=kps_left,
                                        kps_right=kps_right,
                                        joints_left=joints_left,
                                        joints_right=joints_right)
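    # UnchunkedGenerator yields each test sequence in full (no chunking), edge-padding the
    # 2D input by pad (+/- causal shift) frames so the model outputs one 3D pose per input
    # frame. Augmentation is disabled here; horizontal flipping is only used when
    # --test-time-augmentation is enabled (see the rendering branch below).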
    print('INFO: Testing on {} frames'.format(test_generator.num_frames()))

    # Evaluate
    def evaluate(eval_generator,
                 action=None,
                 return_predictions=False,
                 use_trajectory_model=False):
        epoch_loss_3d_pos = 0
        epoch_loss_3d_pos_procrustes = 0
        epoch_loss_3d_pos_scale = 0
        epoch_loss_3d_vel = 0
        with torch.no_grad():
            if not use_trajectory_model:
                model_pos.eval()
            else:
                model_traj.eval()
            N = 0
            for _, batch, batch_2d in eval_generator.next_epoch():
                inputs_2d = torch.from_numpy(batch_2d.astype('float32'))
                if torch.cuda.is_available():
                    inputs_2d = inputs_2d.cuda()

                # Positional model
                if not use_trajectory_model:
                    predicted_3d_pos = model_pos(inputs_2d)
                else:
                    predicted_3d_pos = model_traj(inputs_2d)

                # Test-time augmentation (if enabled)
                if eval_generator.augment_enabled():
                    # Undo flipping and take average with non-flipped version
                    predicted_3d_pos[1, :, :, 0] *= -1
                    if not use_trajectory_model:
                        predicted_3d_pos[1, :, joints_left +
                                         joints_right] = predicted_3d_pos[
                                             1, :, joints_right + joints_left]
                    predicted_3d_pos = torch.mean(predicted_3d_pos,
                                                  dim=0,
                                                  keepdim=True)

                if return_predictions:
                    return predicted_3d_pos.squeeze(0).cpu().numpy()

                inputs_3d = torch.from_numpy(batch.astype('float32'))
                if torch.cuda.is_available():
                    inputs_3d = inputs_3d.cuda()
                inputs_3d[:, :, 0] = 0
                if eval_generator.augment_enabled():
                    inputs_3d = inputs_3d[:1]

                error = mpjpe(predicted_3d_pos, inputs_3d)
                epoch_loss_3d_pos_scale += inputs_3d.shape[
                    0] * inputs_3d.shape[1] * n_mpjpe(predicted_3d_pos,
                                                      inputs_3d).item()

                epoch_loss_3d_pos += inputs_3d.shape[0] * inputs_3d.shape[
                    1] * error.item()
                N += inputs_3d.shape[0] * inputs_3d.shape[1]

                inputs = inputs_3d.cpu().numpy().reshape(
                    -1, inputs_3d.shape[-2], inputs_3d.shape[-1])
                predicted_3d_pos = predicted_3d_pos.cpu().numpy().reshape(
                    -1, inputs_3d.shape[-2], inputs_3d.shape[-1])

                epoch_loss_3d_pos_procrustes += inputs_3d.shape[
                    0] * inputs_3d.shape[1] * p_mpjpe(predicted_3d_pos, inputs)

                # Compute velocity error
                epoch_loss_3d_vel += inputs_3d.shape[0] * inputs_3d.shape[
                    1] * mean_velocity_error(predicted_3d_pos, inputs)

        if action is None:
            print('----------')
        else:
            print('----' + action + '----')
        e1 = (epoch_loss_3d_pos / N) * 1000
        e2 = (epoch_loss_3d_pos_procrustes / N) * 1000
        e3 = (epoch_loss_3d_pos_scale / N) * 1000
        ev = (epoch_loss_3d_vel / N) * 1000
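        # Reported metrics (millimetres; poses are in metres, hence the * 1000):
        # protocol #1 MPJPE is the mean per-joint position error, protocol #2 P-MPJPE is
        # computed after rigid (Procrustes) alignment, protocol #3 N-MPJPE after
        # scale-only alignment, and MPJVE is the mean per-joint velocity error.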
        print('Test time augmentation:', eval_generator.augment_enabled())
        print('Protocol #1 Error (MPJPE):', e1, 'mm')
        print('Protocol #2 Error (P-MPJPE):', e2, 'mm')
        print('Protocol #3 Error (N-MPJPE):', e3, 'mm')
        print('Velocity Error (MPJVE):', ev, 'mm')
        print('----------')

        return e1, e2, e3, ev

    if args.render:
        print('Rendering...')

        input_keypoints = keypoints[args.viz_subject][args.viz_action][
            args.viz_camera].copy()
        ground_truth = None
        if args.viz_subject in dataset.subjects(
        ) and args.viz_action in dataset[args.viz_subject]:
            if 'positions_3d' in dataset[args.viz_subject][args.viz_action]:
                ground_truth = dataset[args.viz_subject][
                    args.viz_action]['positions_3d'][args.viz_camera].copy()
        if ground_truth is None:
            print(
                'INFO: this action is unlabeled. Ground truth will not be rendered.'
            )

        gen = UnchunkedGenerator(None,
                                 None, [input_keypoints],
                                 pad=pad,
                                 causal_shift=causal_shift,
                                 augment=args.test_time_augmentation,
                                 kps_left=kps_left,
                                 kps_right=kps_right,
                                 joints_left=joints_left,
                                 joints_right=joints_right)
        prediction = evaluate(gen, return_predictions=True)
        if model_traj is not None and ground_truth is None:
            prediction_traj = evaluate(gen,
                                       return_predictions=True,
                                       use_trajectory_model=True)
            prediction += prediction_traj

        if args.viz_export is not None:
            print('Exporting joint positions to', args.viz_export)
            # Predictions are in camera space
            np.save(args.viz_export, prediction)

        if args.viz_output is not None:
            if ground_truth is not None:
                # Reapply trajectory
                trajectory = ground_truth[:, :1]
                ground_truth[:, 1:] += trajectory
                prediction += trajectory

            # Invert camera transformation
            cam = dataset.cameras()[args.viz_subject][args.viz_camera]
            if ground_truth is not None:
                prediction = camera_to_world(prediction,
                                             R=cam['orientation'],
                                             t=cam['translation'])
                ground_truth = camera_to_world(ground_truth,
                                               R=cam['orientation'],
                                               t=cam['translation'])
            else:
                # If ground truth is not available, take the camera extrinsics from an arbitrary
                # subject; they are nearly identical across subjects and only needed for visualization.
                for subject in dataset.cameras():
                    if 'orientation' in dataset.cameras()[subject][
                            args.viz_camera]:
                        rot = dataset.cameras()[subject][
                            args.viz_camera]['orientation']
                        break
                prediction = camera_to_world(prediction, R=rot, t=0)
                # We don't have the trajectory, but at least we can rebase the height
                prediction[:, :, 2] -= np.min(prediction[:, :, 2])

            anim_output = {'Reconstruction': prediction}
            if ground_truth is not None and not args.viz_no_ground_truth:
                anim_output['Ground truth'] = ground_truth

            input_keypoints = image_coordinates(input_keypoints[..., :2],
                                                w=cam['res_w'],
                                                h=cam['res_h'])

            print("Writing to json")

            import json
            # Format the data like the MediaPipe output so it can be loaded in Unity with the same
            # script: a list of frames, each a list of 3D landmarks.
            # Note: prediction only contains 17 Human3.6M joints, while the Unity script expects 25 landmarks.
            unity_landmarks = prediction.tolist()
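            # The exported structure is a nested list of shape [n_frames][17][3], i.e. one
            # [x, y, z] triple per Human3.6M joint per frame; the 17-vs-25 landmark mismatch
            # noted above is not resolved here.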

            with open(args.output_json, "w") as json_file:
                json.dump(unity_landmarks, json_file)

            if args.rendervideo == "yes":

                from common.visualization import render_animation
                render_animation(input_keypoints,
                                 keypoints_metadata,
                                 anim_output,
                                 dataset.skeleton(),
                                 dataset.fps(),
                                 args.viz_bitrate,
                                 cam['azimuth'],
                                 args.viz_output,
                                 limit=args.viz_limit,
                                 downsample=args.viz_downsample,
                                 size=args.viz_size,
                                 input_video_path=args.viz_video,
                                 viewport=(cam['res_w'], cam['res_h']),
                                 input_video_skip=args.viz_skip)