def main(args):
    """Train or evaluate a multi-person SemGCN 3D human-pose model.

    Loads the 3D dataset and 2D keypoint detections, builds the GCN model,
    then either evaluates a checkpoint action-by-action or runs the full
    training loop, logging metrics and saving best/periodic checkpoints.

    Args:
        args: parsed command-line namespace (dataset, keypoints, lr, epochs,
            checkpoint/resume/evaluate paths, batch size, etc.).

    Raises:
        KeyError: if ``args.dataset`` is not a supported dataset.
        RuntimeError: if a resume/evaluate checkpoint path does not exist.
    """
    print('==> Using settings {}'.format(args))

    print('==> Loading dataset...')
    dataset_path = path.join('data', 'data_3d_' + args.dataset + '.npz')
    if args.dataset == 'h36m':
        from common.h36m_dataset import Human36mDataset, TRAIN_SUBJECTS, TEST_SUBJECTS
        dataset = Human36mDataset(dataset_path)
        subjects_train = TRAIN_SUBJECTS
        subjects_test = TEST_SUBJECTS
    else:
        raise KeyError('Invalid dataset')

    print('==> Preparing data...')
    dataset = read_3d_data(dataset)

    print('==> Loading 2D detections...')
    keypoints = create_2d_data(
        path.join('data',
                  'data_2d_' + args.dataset + '_' + args.keypoints + '.npz'),
        dataset)

    action_filter = None if args.actions == '*' else args.actions.split(',')
    if action_filter is not None:
        # BUGFIX: materialize as a list. A bare map() iterator has no len()
        # (needed by the evaluation branch below) and would be exhausted by
        # the print on the next line, leaving nothing to iterate.
        action_filter = [dataset.define_actions(x)[0] for x in action_filter]
        print('==> Selected actions: {}'.format(action_filter))

    stride = args.downsample
    cudnn.benchmark = True
    device = torch.device("cuda")  # NOTE(review): assumes CUDA is available

    # Create model
    print("==> Creating model...")

    adj, adj_mutual = adj_mx_from_skeleton(
        dataset.skeleton())  # multi-person adjacency matrices
    model_pos = SemGCN(adj,
                       adj_mutual,
                       args.hid_dim,
                       num_layers=args.num_layers,
                       nodes_group=dataset.skeleton().joints_group()
                       if args.non_local else None).to(device)
    print("==> Total parameters: {:.2f}M".format(
        sum(p.numel() for p in model_pos.parameters()) / 1000000.0))

    criterion = nn.MSELoss(reduction='mean').to(device)
    optimizer = torch.optim.Adam(model_pos.parameters(), lr=args.lr)

    # Optionally resume from a checkpoint
    if args.resume or args.evaluate:
        ckpt_path = args.resume if args.resume else args.evaluate

        if path.isfile(ckpt_path):
            print("=> Loading checkpoint '{}'".format(ckpt_path))
            ckpt = torch.load(ckpt_path)
            start_epoch = ckpt['epoch']
            error_best = ckpt['error']
            glob_step = ckpt['step']
            lr_now = ckpt['lr']
            model_pos.load_state_dict(ckpt['state_dict'])
            optimizer.load_state_dict(ckpt['optimizer'])
            print("=> Loaded checkpoint (Epoch: {} | Error: {})".format(
                start_epoch, error_best))

            if args.resume:
                # Keep appending to the log that lives next to the checkpoint.
                ckpt_dir_path = path.dirname(ckpt_path)
                logger = Logger(path.join(ckpt_dir_path, 'log.txt'),
                                resume=True)
        else:
            raise RuntimeError(
                "=> No checkpoint found at '{}'".format(ckpt_path))
    else:
        # Fresh run: start from scratch in a new checkpoint directory whose
        # name encodes the key hyper-parameters.
        start_epoch = 0
        error_best = None
        glob_step = 0
        lr_now = args.lr
        ckpt_dir_path = path.join(
            args.checkpoint,
            datetime.datetime.now().isoformat() +
            "_l_%04d_hid_%04d_e_%04d_non_local_%d" %
            (args.num_layers, args.hid_dim, args.epochs, args.non_local))

        if not path.exists(ckpt_dir_path):
            os.makedirs(ckpt_dir_path)
            print('=> Making checkpoint dir: {}'.format(ckpt_dir_path))

        logger = Logger(os.path.join(ckpt_dir_path, 'log.txt'))
        logger.set_names(
            ['epoch', 'lr', 'loss_train', 'error_eval_p1', 'error_eval_p2'])

    if args.evaluate:
        # Evaluation-only mode: report per-action MPJPE / P-MPJPE and exit.
        print('==> Evaluating...')

        if action_filter is None:
            action_filter = dataset.define_actions()

        errors_p1 = np.zeros(len(action_filter))
        errors_p2 = np.zeros(len(action_filter))

        for i, action in enumerate(action_filter):
            poses_valid, poses_valid_2d, actions_valid = fetch(
                subjects_test, dataset, keypoints, [action], stride)
            valid_loader = DataLoader(PoseGenerator(poses_valid,
                                                    poses_valid_2d,
                                                    actions_valid),
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
            errors_p1[i], errors_p2[i] = evaluate(valid_loader, model_pos,
                                                  device)

        print('Protocol #1   (MPJPE) action-wise average: {:.2f} (mm)'.format(
            np.mean(errors_p1).item()))
        print('Protocol #2 (P-MPJPE) action-wise average: {:.2f} (mm)'.format(
            np.mean(errors_p2).item()))
        exit(0)

    # Build the train / validation loaders for the full training run.
    poses_train, poses_train_2d, actions_train = fetch(subjects_train, dataset,
                                                       keypoints,
                                                       action_filter, stride)
    train_loader = DataLoader(PoseGenerator(poses_train, poses_train_2d,
                                            actions_train),
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)

    poses_valid, poses_valid_2d, actions_valid = fetch(subjects_test, dataset,
                                                       keypoints,
                                                       action_filter, stride)
    valid_loader = DataLoader(PoseGenerator(poses_valid, poses_valid_2d,
                                            actions_valid),
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              pin_memory=True)

    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr_now))

        # Train for one epoch
        epoch_loss, lr_now, glob_step = train(train_loader,
                                              model_pos,
                                              criterion,
                                              optimizer,
                                              device,
                                              args.lr,
                                              lr_now,
                                              glob_step,
                                              args.lr_decay,
                                              args.lr_gamma,
                                              max_norm=args.max_norm)

        # Evaluate
        error_eval_p1, error_eval_p2 = evaluate(valid_loader, model_pos,
                                                device)

        # Update log file
        logger.append(
            [epoch + 1, lr_now, epoch_loss, error_eval_p1, error_eval_p2])

        # Save checkpoint whenever Protocol-#1 error improves.
        if error_best is None or error_best > error_eval_p1:
            error_best = error_eval_p1
            save_ckpt(
                {
                    'epoch': epoch + 1,
                    'lr': lr_now,
                    'step': glob_step,
                    'state_dict': model_pos.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'error': error_eval_p1
                },
                ckpt_dir_path,
                suffix='best')

        # Periodic snapshot every args.snapshot epochs.
        if (epoch + 1) % args.snapshot == 0:
            save_ckpt(
                {
                    'epoch': epoch + 1,
                    'lr': lr_now,
                    'step': glob_step,
                    'state_dict': model_pos.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'error': error_eval_p1
                }, ckpt_dir_path)

    logger.close()
    logger.plot(['loss_train', 'error_eval_p1'])
    savefig(path.join(ckpt_dir_path, 'log.eps'))

    return
# --- Example 2: alternate main() that also supports the "synmit" datasets ---
def main(args):
    """Train or evaluate a multi-view SemGCN 3D human-pose model.

    Supports the Human3.6M dataset ('h36m') and the synthetic 'synmit'
    datasets. Builds data loaders and the model, optionally resumes from
    (or evaluates) a checkpoint, then runs the training loop, logging
    metrics and saving best/periodic checkpoints.

    Args:
        args: parsed command-line namespace (dataset, keypoints, lr, epochs,
            checkpoint/resume/evaluate paths, multiview flag, etc.).

    Raises:
        KeyError: if ``args.dataset`` is not a supported dataset.
        RuntimeError: if a resume/evaluate checkpoint path does not exist.
    """
    print('==> Using settings {}'.format(args))

    print('==> Loading dataset...')
    if args.dataset == 'h36m':
        dataset_path = path.join('data', 'data_3d_' + args.dataset + '.npz')
        from common.h36m_dataset import Human36mDataset, TRAIN_SUBJECTS, TEST_SUBJECTS
        dataset = Human36mDataset(dataset_path)
        subjects_train = TRAIN_SUBJECTS
        subjects_test = TEST_SUBJECTS

        print('==> Preparing data ' + args.dataset + "...")
        dataset = read_3d_data(dataset)
        adj = adj_mx_from_skeleton(dataset.skeleton())
        nodes_group = dataset.skeleton().joints_group() if args.non_local else None

        print('==> Loading 2D detections...')
        keypoints = create_2d_data(path.join('data', 'data_2d_' + args.dataset + '_' + args.keypoints + '.npz'), dataset)
        action_filter = None if args.actions == '*' else args.actions.split(',')

        stride = args.downsample
        if action_filter is not None:
            # BUGFIX: materialize as a list. A bare map() iterator has no
            # len() (needed by the evaluation branch below) and would be
            # exhausted by the print on the next line.
            action_filter = [dataset.define_actions(x)[0] for x in action_filter]
            print('==> Selected actions: {}'.format(action_filter))

        if not args.evaluate:
            print('==> Build DataLoader...')
            poses_train, poses_train_2d, actions_train = fetch(subjects_train, dataset, keypoints, action_filter, stride)
            train_loader = DataLoader(PoseGenerator(poses_train, poses_train_2d, actions_train), batch_size=args.batch_size,
                                      shuffle=True, num_workers=args.num_workers, pin_memory=True)
            poses_valid, poses_valid_2d, actions_valid = fetch(subjects_test, dataset, keypoints, action_filter, stride)
            valid_loader = DataLoader(PoseGenerator(poses_valid, poses_valid_2d, actions_valid), batch_size=args.batch_size,
                                      shuffle=False, num_workers=args.num_workers, pin_memory=True)
    elif "synmit" in args.dataset:
        # Synthetic datasets ship their own Dataset classes, so loaders are
        # built directly from them (no fetch/PoseGenerator step).
        dataset_path = args.dataset_path
        from common.synmit_dataset import SynDataset17, SynDataset17_h36m
        if "h36m" in args.dataset:
            train_dataset = SynDataset17_h36m(dataset_path, image_set="train")
            valid_dataset = SynDataset17_h36m(dataset_path, image_set="val")
        else:
            train_dataset = SynDataset17(dataset_path, image_set="train")
            valid_dataset = SynDataset17(dataset_path, image_set="val")
        dataset = train_dataset
        adj = adj_mx_from_edges(dataset.jointcount(), dataset.edge(), sparse=False)
        nodes_group = dataset.nodesgroup() if args.non_local else None
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=args.num_workers, pin_memory=True)
        valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size,
                                  shuffle=False, num_workers=args.num_workers, pin_memory=True)
    else:
        raise KeyError('Invalid dataset')


    cudnn.benchmark = True
    device = torch.device("cuda")  # NOTE(review): assumes CUDA is available

    # Create model
    print("==> Creating model...")
    p_dropout = (None if args.dropout == 0.0 else args.dropout)
    view_count = dataset.view_count if args.multiview else None
    model_pos = MultiviewSemGCN(adj, args.hid_dim, coords_dim=(3, 1), num_layers=args.num_layers, p_dropout=p_dropout,
                                view_count=view_count, nodes_group=nodes_group).to(device)
    print("==> Total parameters: {:.2f}M".format(sum(p.numel() for p in model_pos.parameters()) / 1000000.0))

    criterion = nn.MSELoss(reduction='mean').to(device)
    optimizer = torch.optim.Adam(model_pos.parameters(), lr=args.lr)

    # Optionally resume from a checkpoint
    if args.resume or args.evaluate:
        ckpt_path = (args.resume if args.resume else args.evaluate)

        if path.isfile(ckpt_path):
            print("==> Loading checkpoint '{}'".format(ckpt_path))
            ckpt = torch.load(ckpt_path)
            start_epoch = ckpt['epoch']
            error_best = ckpt['error']
            glob_step = ckpt['step']
            lr_now = ckpt['lr']
            # lr_now = args.lr
            # Older h36m checkpoints used "nonlocal" in layer names; rename
            # to the current "nonlocal_layer" so load_state_dict matches.
            if args.dataset == "h36m":
                for k in list(ckpt['state_dict'].keys()):
                    v = ckpt['state_dict'].pop(k)
                    if (type(k) == str) and ("nonlocal" in k):
                        k = k.replace("nonlocal", "nonlocal_layer")
                    ckpt['state_dict'][k] = v

            model_pos.load_state_dict(ckpt['state_dict'])
            optimizer.load_state_dict(ckpt['optimizer'])
            print("==> Loaded checkpoint (Epoch: {} | Error: {})".format(start_epoch, error_best))

            if args.resume:
                # Keep appending to the log that lives next to the checkpoint.
                ckpt_dir_path = path.dirname(ckpt_path)
                logger = Logger(path.join(ckpt_dir_path, 'log.txt'), resume=True)
        else:
            raise RuntimeError("==> No checkpoint found at '{}'".format(ckpt_path))
    else:
        # Fresh run: new checkpoint directory named after dataset/time;
        # ':' is stripped so the path is valid on Windows as well.
        start_epoch = 0
        error_best = None
        glob_step = 0
        lr_now = args.lr
        mv_str = "mv_" if args.multiview else ""
        ckpt_dir_path = path.join(args.checkpoint, args.dataset + "_" + mv_str + datetime.datetime.now().isoformat(timespec="seconds"))
        ckpt_dir_path = ckpt_dir_path.replace(":", "-")

        if not path.exists(ckpt_dir_path):
            os.makedirs(ckpt_dir_path)
            print('==> Making checkpoint dir: {}'.format(ckpt_dir_path))

        logger = Logger(os.path.join(ckpt_dir_path, 'log.txt'))
        logger.set_names(['epoch', 'lr', 'loss_train', 'error_eval_p1', 'error_eval_p2'])

    if args.evaluate:
        # Evaluation-only mode: report MPJPE / REL-MPJPE and exit.
        print('==> Evaluating...')
        if args.dataset == 'h36m':
            if action_filter is None:
                action_filter = dataset.define_actions()

            errors_p1 = np.zeros(len(action_filter))
            errors_p2 = np.zeros(len(action_filter))

            for i, action in enumerate(action_filter):
                poses_valid, poses_valid_2d, actions_valid = fetch(subjects_test, dataset, keypoints, [action], stride)
                valid_loader = DataLoader(PoseGenerator(poses_valid, poses_valid_2d, actions_valid),
                                          batch_size=args.batch_size, shuffle=False,
                                          num_workers=args.num_workers, pin_memory=True)
                errors_p1[i], errors_p2[i] = evaluate(valid_loader, model_pos, device)
        elif "synmit" in args.dataset:
            if "h36m" in args.dataset:
                test_dataset = SynDataset17_h36m(dataset_path, image_set="test")
            else:
                test_dataset = SynDataset17(dataset_path, image_set="test")
            test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                                      shuffle=False, num_workers=args.num_workers, pin_memory=True)
            errors_p1, errors_p2 = evaluate(test_loader, model_pos, device)

        print('Protocol #1 (MPJPE)     action-wise average: {:.2f} (mm)'.format(np.mean(errors_p1).item()))
        print('Protocol #2 (REL-MPJPE) action-wise average: {:.2f} (mm)'.format(np.mean(errors_p2).item()))
        exit(0)

    epoch_loss = 1e5  # initial loss sentinel fed into train() on epoch 0
    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr_now))

        # Train for one epoch
        epoch_loss, lr_now, glob_step = train(epoch_loss, train_loader, model_pos, criterion, optimizer, device, args.lr, lr_now,
                                              glob_step, args.lr_decay, args.lr_gamma, max_norm=args.max_norm, loss_3d=args.loss3d,
                                              earlyend=args.earlyend)

        # Evaluate
        error_eval_p1, error_eval_p2 = evaluate(valid_loader, model_pos, device)

        # Update log file
        logger.append([epoch + 1, lr_now, epoch_loss, error_eval_p1, error_eval_p2])

        # Save checkpoint whenever Protocol-#1 error improves.
        if error_best is None or error_best > error_eval_p1:
            error_best = error_eval_p1
            save_ckpt({'epoch': epoch + 1, 'lr': lr_now, 'step': glob_step, 'state_dict': model_pos.state_dict(),
                       'optimizer': optimizer.state_dict(), 'error': error_eval_p1}, ckpt_dir_path, suffix='best')

        # Periodic snapshot every args.snapshot epochs.
        if (epoch + 1) % args.snapshot == 0:
            save_ckpt({'epoch': epoch + 1, 'lr': lr_now, 'step': glob_step, 'state_dict': model_pos.state_dict(),
                       'optimizer': optimizer.state_dict(), 'error': error_eval_p1}, ckpt_dir_path)

    logger.close()
    logger.plot(['loss_train', 'error_eval_p1'])
    savefig(path.join(ckpt_dir_path, 'log.eps'))

    return