Example #1
def test_one_scene_pseudo(pseudo_labels, dataloader, save_folder, start_index):

    os.makedirs(save_folder, exist_ok=True)

    for iter_idx, (images_batch, proj_mats_batch, joints_3d_gt_batch, joints_3d_valid_batch, joints_2d_gt_batch, indexes) \
            in enumerate(dataloader):
        
        if images_batch is None:
            continue
        
        batch_size = images_batch.shape[0]
        indexes_abs = [idx + start_index for idx in indexes]

        p = 0.2 # percentage
        score_thresh = consistency.get_score_thresh(pseudo_labels, p, separate=True)
        joints_2d_pl_batch, joints_2d_valid_batch = \
                                 consistency.get_pseudo_labels(pseudo_labels, indexes_abs, images_batch.shape[1], score_thresh)
        
        for batch_i in range(batch_size):
            # overlay the pseudo labels on each sample and save it as a PNG
            vis = visualize.visualize_pseudo_labels(images_batch[batch_i], joints_2d_pl_batch[batch_i], joints_2d_valid_batch[batch_i])
            im = Image.fromarray(vis)
            img_path = os.path.join(save_folder, "%06d.png" % (iter_idx * batch_size + batch_i))
            im.save(img_path)
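
A minimal driver sketch for the function above; the pseudo-label path and the `scene_loader` object are assumptions (any loader that yields the same 6-tuple batches as in the loop works), and `consistency`/`visualize` are the project's own modules.

import numpy as np

# hypothetical paths/objects: adjust to the actual project setup
pseudo_labels = np.load("pseudo_labels/human36m_train.npy",
                        allow_pickle=True).item()  # dict saved via np.save
test_one_scene_pseudo(pseudo_labels, scene_loader,
                      save_folder="vis/pseudo_labels_scene_0",
                      start_index=0)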
Example #2
def train_one_epoch_ssl(config, model, syn_train_loader, real_train_loader, \
                        criterion, metric, opt, e, device, \
                        checkpoint_dir, writer=None, \
                        gamma=10, \
                        log_every_iters=1, vis_every_iters=1):
    model.train()
    batch_size = syn_train_loader.batch_size
    iters_per_epoch = round(min(len(syn_train_loader.dataset) / syn_train_loader.batch_size, \
                                len(real_train_loader.dataset) / real_train_loader.batch_size))
    print("Estimated iterations per epoch is %d." % iters_per_epoch)

    total_train_loss_syn = 0
    total_detected_syn = 0
    total_error_syn = 0
    total_samples_syn = 0  # num_joints or num_frames

    total_train_loss_real = 0
    total_detected_real = 0
    total_error_real = 0
    total_samples_real = 0  # num_joints or num_frames
    total_detected_per_joint_real = torch.zeros((17, )).to(device)  # hardcoded
    total_num_per_joint_real = torch.zeros((17, )).to(device)  # hardcoded

    # jointly training
    joint_loader = zip(syn_train_loader, real_train_loader)

    for iter_idx, ((syn_images_batch, syn_proj_mats_batch, syn_joints_3d_gt_batch, syn_joints_3d_valid_batch, syn_joints_2d_gt_batch, syn_info_batch), \
                   (real_images_batch, real_proj_mats_batch, real_joints_3d_gt_batch, real_joints_3d_valid_batch, real_joints_2d_gt_batch, real_indexes)) \
                   in enumerate(joint_loader):

        opt.zero_grad()

        # train on syndata
        if syn_images_batch is None:
            continue

        syn_images_batch = syn_images_batch.to(device)
        syn_proj_mats_batch = syn_proj_mats_batch.to(device)
        syn_joints_3d_gt_batch = syn_joints_3d_gt_batch.to(device)
        syn_joints_3d_valid_batch = syn_joints_3d_valid_batch.to(device)
        syn_joints_2d_gt_batch = syn_joints_2d_gt_batch.to(device)

        syn_joints_3d_pred, syn_joints_2d_pred, syn_heatmaps_pred, syn_confidences_pred, _ = model(
            syn_images_batch, syn_proj_mats_batch)

        if isinstance(criterion, HeatmapMSELoss):
            syn_loss = criterion(syn_heatmaps_pred, syn_joints_2d_gt_batch)
        else:
            syn_loss = criterion(syn_joints_3d_pred, syn_joints_3d_gt_batch,
                                 syn_joints_3d_valid_batch)

        # train on h36m
        if real_images_batch is None:
            continue

        real_images_batch = real_images_batch.to(device)
        if real_proj_mats_batch is not None:
            real_proj_mats_batch = real_proj_mats_batch.to(device)
            real_joints_3d_gt_batch = real_joints_3d_gt_batch.to(device)
            real_joints_3d_valid_batch = real_joints_3d_valid_batch.to(device)
        real_joints_2d_gt_batch = real_joints_2d_gt_batch.to(device)

        real_joints_3d_pred, real_joints_2d_pred, real_heatmaps_pred, real_confidences_pred, _ = model(
            real_images_batch, real_proj_mats_batch)

        pseudo_labels = np.load("pseudo_labels/%s_train.npy" %
                                config.dataset.type,
                                allow_pickle=True).item()  # load pseudo labels
        # p = 0.2 * (e // 10 + 1) # percentage
        p = 0.2  # hardcoded
        score_thresh = consistency.get_score_thresh(pseudo_labels,
                                                    p,
                                                    separate=True)
        real_joints_2d_pl_batch, real_joints_2d_valid_batch = \
                                 consistency.get_pseudo_labels(pseudo_labels, real_indexes, real_images_batch.shape[1], score_thresh)
        real_joints_2d_pl_batch = real_joints_2d_pl_batch.to(device)
        real_joints_2d_valid_batch = real_joints_2d_valid_batch.to(device)
        """
        # debug
        if iter_idx == 0:
            for batch_idx in range(real_joints_2d_gt_batch.shape[0]):
                for view_idx in range(real_joints_2d_gt_batch.shape[1]):
                    for j in range(real_joints_2d_gt_batch.shape[2]):
                        plt.imshow(real_images_batch[batch_idx, view_idx, 0, :, :].detach().cpu().numpy())
                        plt.scatter(real_joints_2d_gt_batch[batch_idx, view_idx, j, 0].detach().cpu().numpy(), \
                                real_joints_2d_gt_batch[batch_idx, view_idx, j, 1].detach().cpu().numpy(), \
                                s=10, color="red")
                        plt.xlabel("%s" \
                                % real_joints_2d_valid_batch[batch_idx, view_idx, j, 0].detach().cpu().numpy())
                        plt.savefig("hm/%d_%d_%d.png" % (batch_idx, view_idx, j))
                        plt.close()
        """

        if isinstance(criterion, HeatmapMSELoss):
            real_loss = criterion(real_heatmaps_pred, real_joints_2d_pl_batch,
                                  real_joints_2d_valid_batch)
            # real_loss = criterion(real_heatmaps_pred, real_joints_2d_gt_batch, real_joints_2d_valid_batch)
        else:
            raise ValueError(
                "Please use 2D Heatmap Loss for training on real dataset!")

        # optimize
        loss = syn_loss + gamma * real_loss
        loss.backward()
        opt.step()

        # evaluate on syndata
        syn_detected, syn_error, syn_num_samples, _, _\
                      = utils_eval.eval_one_batch(metric, syn_joints_3d_pred, syn_joints_2d_pred, \
                                                  syn_proj_mats_batch, syn_joints_3d_gt_batch, syn_joints_3d_valid_batch, \
                                                  syn_joints_2d_gt_batch)

        total_train_loss_syn += syn_num_samples * syn_loss.item()
        total_detected_syn += syn_detected
        total_error_syn += syn_num_samples * syn_error
        total_samples_syn += syn_num_samples

        # evaluate on h36m
        real_detected, real_error, real_num_samples, real_detected_per_joint, real_num_per_joint \
                       = utils_eval.eval_one_batch(metric, real_joints_3d_pred, real_joints_2d_pred, \
                                                   real_proj_mats_batch, real_joints_3d_gt_batch, real_joints_3d_valid_batch, \
                                                   real_joints_2d_gt_batch)

        total_train_loss_real += real_num_samples * real_loss.item()
        total_detected_real += real_detected
        total_error_real += real_num_samples * real_error
        total_samples_real += real_num_samples
        total_detected_per_joint_real += real_detected_per_joint  # size: num_joints
        total_num_per_joint_real += real_num_per_joint  # size: num_joints

        # logger
        if iter_idx % log_every_iters == log_every_iters - 1:
            logging_iter = iter_idx + 1 - log_every_iters
            mean_loss_logging_syn = total_train_loss_syn / total_samples_syn
            pck_acc_logging_syn = total_detected_syn / total_samples_syn
            mean_error_logging_syn = total_error_syn / total_samples_syn
            mean_loss_logging_real = total_train_loss_real / total_samples_real
            pck_acc_logging_real = total_detected_real / total_samples_real
            mean_error_logging_real = total_error_real / total_samples_real
            pck_per_joint_logging_real = total_detected_per_joint_real / total_num_per_joint_real  # size: num_joints
            print("epoch: %d, iter: %d" % (e, logging_iter))
            print("        (Syndata) train loss: %f, train acc: %.3f, train error: %.3f" \
                  % (mean_loss_logging_syn, pck_acc_logging_syn, mean_error_logging_syn))
            print("        (Real) train loss: %f, train acc: %.3f, train error: %.3f" \
                  % (mean_loss_logging_real, pck_acc_logging_real, mean_error_logging_real))

            if writer is not None:
                writer.add_scalar("train_loss/syndata/iter",
                                  mean_loss_logging_syn,
                                  e * iters_per_epoch + logging_iter)
                writer.add_scalar("train_pck/syndata/iter",
                                  pck_acc_logging_syn,
                                  e * iters_per_epoch + logging_iter)
                writer.add_scalar("train_error/syndata/iter",
                                  mean_error_logging_syn,
                                  e * iters_per_epoch + logging_iter)
                writer.add_scalar("train_loss/real/iter",
                                  mean_loss_logging_real,
                                  e * iters_per_epoch + logging_iter)
                writer.add_scalar("train_pck/real/iter", pck_acc_logging_real,
                                  e * iters_per_epoch + logging_iter)
                writer.add_scalar("train_error/real/iter",
                                  mean_error_logging_real,
                                  e * iters_per_epoch + logging_iter)
                for jnt_i in range(len(pck_per_joint_logging_real)):
                    writer.add_scalar("train_pck/real/iter/joint_%d" % jnt_i,
                                      pck_per_joint_logging_real[jnt_i],
                                      e * iters_per_epoch + logging_iter)

        # save images
        if iter_idx % vis_every_iters == 0:
            vis_iter = iter_idx
            # visualize first sample in batch
            if writer is not None:
                # joints_vis_syn = visualize.visualize_pred(syn_images_batch[0], syn_proj_mats_batch[0], syn_joints_3d_gt_batch[0], \
                #                                           syn_joints_3d_pred[0], syn_joints_2d_pred[0])
                joints_vis_syn = visualize.visualize_pred_2D(
                    syn_images_batch[0], syn_joints_2d_gt_batch[0],
                    syn_joints_2d_pred[0])
                writer.add_image("joints/syndata/iter",
                                 joints_vis_syn.transpose(2, 0, 1),
                                 global_step=e * iters_per_epoch + vis_iter)
                # joints_vis_h36m = visualize.visualize_pred(h36m_images_batch[0], h36m_proj_mats_batch[0], h36m_joints_3d_gt_batch[0], \
                #                                            h36m_joints_3d_pred[0], h36m_joints_2d_pred[0])
                joints_vis_real = visualize.visualize_pred_2D(
                    real_images_batch[0], real_joints_2d_gt_batch[0],
                    real_joints_2d_pred[0])
                writer.add_image("joints/real/iter",
                                 joints_vis_real.transpose(2, 0, 1),
                                 global_step=e * iters_per_epoch + vis_iter)

                vis_joint = (iter_idx // vis_every_iters) % 16
                heatmap_vis_syn = visualize.visualize_heatmap(syn_images_batch[0], syn_joints_2d_gt_batch[0], \
                                                              syn_heatmaps_pred[0], vis_joint=vis_joint)
                writer.add_image("heatmap/syndata/joint_%d/iter" % vis_joint,
                                 heatmap_vis_syn.transpose(2, 0, 1),
                                 global_step=e * iters_per_epoch + vis_iter)
                heatmap_vis_real = visualize.visualize_heatmap(real_images_batch[0], real_joints_2d_gt_batch[0], \
                                                               real_heatmaps_pred[0], vis_joint=vis_joint)
                writer.add_image("heatmap/real/joint_%d/iter" % vis_joint,
                                 heatmap_vis_real.transpose(2, 0, 1),
                                 global_step=e * iters_per_epoch + vis_iter)

    # save logging per epoch to tensorboard
    mean_loss_syn = total_train_loss_syn / total_samples_syn
    pck_acc_syn = total_detected_syn / total_samples_syn
    mean_error_syn = total_error_syn / total_samples_syn
    mean_loss_real = total_train_loss_real / total_samples_real
    pck_acc_real = total_detected_real / total_samples_real
    mean_error_real = total_error_real / total_samples_real
    pck_per_joint_real = total_detected_per_joint_real / total_num_per_joint_real  # size: num_joints
    if writer is not None:
        writer.add_scalar("train_loss/syndata/epoch", mean_loss_syn, e)
        writer.add_scalar("train_pck/syndata/epoch", pck_acc_syn, e)
        writer.add_scalar("train_error/syndata/epoch", mean_error_syn, e)
        writer.add_scalar("train_loss/real/epoch", mean_loss_real, e)
        writer.add_scalar("train_pck/real/epoch", pck_acc_real, e)
        writer.add_scalar("train_error/real/epoch", mean_error_real, e)
        for jnt_i in range(len(pck_per_joint_real)):
            writer.add_scalar("train_pck/real/epoch/joint_%d" % jnt_i,
                              pck_per_joint_real[jnt_i], e)

    return mean_loss_syn, pck_acc_syn, mean_error_syn, \
           mean_loss_real, pck_acc_real, mean_error_real
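
A sketch of the outer loop that might call train_one_epoch_ssl; SummaryWriter comes from torch.utils.tensorboard, while config, model, the two loaders, criterion, metric, device and num_epochs are assumed to be built elsewhere in the project.

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# hypothetical training driver
writer = SummaryWriter(log_dir="logs/ssl_run")
opt = optim.Adam(model.parameters(), lr=1e-4)
for e in range(num_epochs):
    train_one_epoch_ssl(config, model, syn_train_loader, real_train_loader,
                        criterion, metric, opt, e, device,
                        checkpoint_dir="checkpoints", writer=writer,
                        gamma=10, log_every_iters=50, vis_every_iters=500)
    torch.save(model.state_dict(), "checkpoints/epoch_%d.pth" % e)
writer.close()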
Example #3
def generate_features(config, model, dataloader, device, label_path,
                      write_path):

    if os.path.exists(write_path):
        print("File %s already exists" % write_path)
        return

    num_joints = config.model.backbone.num_joints

    # fill in parameters
    retval = {
        'dataset': config.dataset.type,
        'split': "train",
        'image_shape': config.dataset.image_shape,
        'num_joints': config.model.backbone.num_joints,
        'scale_bbox': config.dataset.train.scale_bbox,
        'retain_every_n_frames': config.dataset.train.retain_every_n_frames,
        'pseudo_label_path': label_path
    }
    feats_dtype = np.dtype([('data_idx', np.int32), ('view_idx', np.int8),
                            ('feats', np.float32, (num_joints, 256))])
    retval['features'] = []

    # re-configure data loader
    dataloader.shuffle = False

    # load pseudo labels
    pseudo_labels = np.load(label_path, allow_pickle=True).item()

    print("Generating feature vectors...")
    print("Estimated number of iterations is: %d" %
          round(dataloader.dataset.__len__() / dataloader.batch_size))
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        for iter_idx, (images_batch, _, _, _, _,
                       indexes) in enumerate(dataloader):
            if images_batch is None:
                continue

            images_batch = images_batch.to(device)

            batch_size = images_batch.shape[0]
            num_views = images_batch.shape[1]
            image_shape = images_batch.shape[3:]
            assert batch_size == len(indexes)

            _, _, _, _, features = model(images_batch, None)
            feature_shape = features.shape[3:]
            ratio_h = feature_shape[0] / image_shape[0]
            ratio_w = feature_shape[1] / image_shape[1]
            joints_2d_batch, _ = consistency.get_pseudo_labels(
                pseudo_labels, indexes, num_views, 0)
            joints_2d_batch = joints_2d_batch.cpu().numpy()
            feat_xs = (joints_2d_batch[:, :, :, 0] * ratio_w).astype(np.int32)
            feat_ys = (joints_2d_batch[:, :, :, 1] * ratio_h).astype(np.int32)

            # fill in pseudo labels
            for batch_idx, data_idx in enumerate(indexes):
                for view_idx in range(num_views):
                    feat_x = feat_xs[batch_idx, view_idx, :]
                    feat_y = feat_ys[batch_idx, view_idx, :]
                    feats_segment = np.empty(1, dtype=feats_dtype)
                    feats_segment['data_idx'] = data_idx
                    feats_segment['view_idx'] = view_idx
                    feats_segment['feats'] = np.transpose(
                        features[batch_idx, view_idx, :, feat_y,
                                 feat_x].cpu().numpy())

                    retval['features'].append(feats_segment)

    retval['features'] = np.concatenate(retval['features'])
    assert retval['features'].ndim == 1
    print("Total number of images in the dataset: ", len(retval['features']))

    # save features
    save_folder = "feats"
    os.makedirs(save_folder, exist_ok=True)

    print("Saving features to %s" % write_path)
    np.save(write_path, retval)
    print("Done.")
Example #4
def plot_tSNE(dataloader, device, label_path, feat_path):
    # re-configure data loader
    dataloader.shuffle = False

    # load pseudo labels and their features
    pseudo_labels = np.load(label_path, allow_pickle=True).item()
    feats_npy = np.load(feat_path, allow_pickle=True).item()
    num_joints = pseudo_labels['num_joints']

    # filtered features and detections in pseudo labels
    features_pl = []
    detections_pl = []
    for jnt_idx in range(num_joints):
        features_pl.append(np.empty((0, 256)))
        detections_pl.append(np.empty(0, dtype=bool))

    error_thresh = 20
    p = 0.2
    score_thresh = consistency.get_score_thresh(pseudo_labels,
                                                p,
                                                separate=True)
    for iter_idx, (images_batch, _, _, _, joints_2d_gt_batch,
                   indexes) in enumerate(dataloader):
        if images_batch is None:
            continue

        joints_2d_pseudo, joints_2d_valid_batch = \
                consistency.get_pseudo_labels(pseudo_labels, indexes, images_batch.shape[1], score_thresh)

        features_batch = get_features(feats_npy, indexes,
                                      images_batch.shape[1],
                                      joints_2d_valid_batch)
        detections = (torch.norm(joints_2d_pseudo - joints_2d_gt_batch,
                                 dim=-1,
                                 keepdim=True) < error_thresh)

        for jnt_idx in range(num_joints):
            detections_jnt = detections[:, :, jnt_idx, :][
                joints_2d_valid_batch[:, :, jnt_idx, :]]
            features_pl[jnt_idx] = np.vstack(
                (features_pl[jnt_idx], features_batch[jnt_idx]))
            assert len(features_batch[jnt_idx]) == len(detections_jnt)
            detections_pl[jnt_idx] = np.concatenate(
                (detections_pl[jnt_idx], detections_jnt.cpu().numpy()))

    # t-SNE visualization for each joint pseudo label
    for jnt_idx in range(num_joints):
        print("Joint %d:" % jnt_idx)
        time_start = time.time()
        tsne = TSNE(n_components=2)
        tsne_results = tsne.fit_transform(features_pl[jnt_idx])
        print("t-SNE done. Time elapsed: %f secs" % (time.time() - time_start))
        # plot
        pos = tsne_results[detections_pl[jnt_idx], :]
        neg = tsne_results[~detections_pl[jnt_idx], :]
        plt.title("t-SNE visualization")
        plt.scatter(x=pos[:, 0], y=pos[:, 1], alpha=0.3, label="inlier")
        plt.scatter(x=neg[:, 0], y=neg[:, 1], alpha=0.3, label="outlier")
        plt.legend()
        plt.savefig("figs/tsne_pseudo_labels_joint_%d.png" % jnt_idx)
        plt.close()
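
The per-joint plotting above reduces to the pattern below; this stand-alone sketch uses random data in place of features_pl[jnt_idx] and detections_pl[jnt_idx], just to illustrate the t-SNE / scatter workflow.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

feats = np.random.randn(500, 256)       # stand-in for features_pl[jnt_idx]
inlier = np.random.rand(500) > 0.3      # stand-in for detections_pl[jnt_idx]

emb = TSNE(n_components=2).fit_transform(feats)
plt.scatter(emb[inlier, 0], emb[inlier, 1], alpha=0.3, label="inlier")
plt.scatter(emb[~inlier, 0], emb[~inlier, 1], alpha=0.3, label="outlier")
plt.legend()
plt.savefig("tsne_demo.png")
plt.close()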
Example #5
def eval_pseudo_labels(dataset='human36m',
                       p=0.2,
                       separate=True,
                       triangulate=False):  # hardcoded

    if dataset == 'mpii' and triangulate:
        raise ValueError(
            "MPII dataset is not multiview, please use a multiview dataset.")

    print("Loading data ...")
    if dataset == 'human36m':
        train_set = Human36MMultiViewDataset(
            h36m_root=
            "../learnable-triangulation-pytorch/data/human36m/processed/",
            train=True,
            image_shape=[384, 384],
            labels_path=
            "../learnable-triangulation-pytorch/data/human36m/extra/human36m-multiview-labels-GTbboxes.npy",
            with_damaged_actions=True,
            retain_every_n_frames=10,
            scale_bbox=1.6,
            kind="human36m",
            undistort_images=True,
            ignore_cameras=[],
            crop=True,
        )
        train_loader = datasets_utils.human36m_loader(train_set, \
                                                      # batch_size=64, \
                                                      batch_size=4, \
                                                      shuffle=False, \
                                                      num_workers=4)
        pseudo_labels = np.load(
            "pseudo_labels/human36m_train_every_10_frames.npy",
            allow_pickle=True).item()
        thresh = 1

    elif dataset == 'mpii':
        train_set = Mpii(
            image_path="../mpii_images",
            anno_path="../pytorch-pose/data/mpii/mpii_annotations.json",
            inp_res=384,
            out_res=96,
            is_train=True)
        train_loader = datasets_utils.mpii_loader(train_set, \
                                                  batch_size=256, \
                                                  shuffle=False, \
                                                  num_workers=4)
        pseudo_labels = np.load("pseudo_labels/mpii_train.npy",
                                allow_pickle=True).item()
        thresh = 0.5

    print("Data loaded.")
    score_thresh = consistency.get_score_thresh(pseudo_labels,
                                                p,
                                                separate=separate)

    total_joints = 0
    total_detected_pck = 0
    total_detected_pckh = 0
    errors = torch.empty(0)
    for iter_idx, (images_batch, proj_mats_batch, joints_3d_gt_batch,
                   joints_3d_valid_batch, joints_2d_gt_batch,
                   indexes) in enumerate(train_loader):
        print("batch %d" % iter_idx)
        if images_batch is None:
            continue

        joints_2d_pseudo, joints_2d_valid_batch = \
                          consistency.get_pseudo_labels(pseudo_labels, indexes, images_batch.shape[1], score_thresh)

        if triangulate:
            num_valid_before = joints_2d_valid_batch.sum()
            joints_2d_pseudo, joints_2d_valid_batch = triangulate_pseudo_labels(
                proj_mats_batch, joints_2d_pseudo, joints_2d_valid_batch)
            num_valid_after = joints_2d_valid_batch.sum()
            print("Number of valid labels: before: %d, after: %d" %
                  (num_valid_before, num_valid_after))

        detected_pck, num_jnts, _, _ = PCK()(joints_2d_pseudo,
                                             joints_2d_gt_batch,
                                             joints_2d_valid_batch)
        detected_pckh, _, _, _ = PCKh(thresh=thresh)(joints_2d_pseudo,
                                                     joints_2d_gt_batch,
                                                     joints_2d_valid_batch)
        diff = torch.sqrt(
            torch.sum((joints_2d_pseudo - joints_2d_gt_batch)**2,
                      dim=-1,
                      keepdim=True))
        error_2d = diff[joints_2d_valid_batch]
        errors = torch.cat((errors, error_2d))

        total_joints += num_jnts
        total_detected_pck += detected_pck
        total_detected_pckh += detected_pckh

    errors = errors.cpu().numpy()
    print("PCK:", total_detected_pck / total_joints)
    print("PCKh:", total_detected_pckh / total_joints)
    print("Error(2D):", errors.mean())

    # plot histogram
    plt.hist(errors, bins=100, density=True)
    plt.title("2D error distribution in pseudo labels")
    plt.xlabel("2D error (in pixel)")
    plt.ylabel("density")
    plt.savefig("figs/errors_pseudo_labels_%s.png" % dataset)
    plt.close()
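
Possible invocations of the evaluation above (the data and pseudo-label paths hardcoded inside the function must exist for these to run, and a "figs" output directory is assumed):

# hypothetical calls
eval_pseudo_labels(dataset='human36m', p=0.2, separate=True, triangulate=False)
eval_pseudo_labels(dataset='human36m', p=0.2, separate=True, triangulate=True)
eval_pseudo_labels(dataset='mpii', p=0.2, separate=True)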