def test_one_scene_pseudo(pseudo_labels, dataloader, save_folder, start_index,
                          p=0.2):
    """Render pseudo-label visualizations for one scene and save them as PNGs.

    Args:
        pseudo_labels: pseudo-label dict consumed by the `consistency` module.
        dataloader: yields (images, proj_mats, joints_3d_gt, joints_3d_valid,
            joints_2d_gt, indexes) batches; a batch whose images are None is
            skipped (collate produces None for empty batches — presumably;
            verify against the dataset's collate_fn).
        save_folder: output directory for rendered images (created if missing).
        start_index: offset added to per-batch indexes to obtain absolute
            dataset indexes into `pseudo_labels`.
        p: percentage used to derive the confidence-score threshold
            (default 0.2, the previously hard-coded value).
    """
    os.makedirs(save_folder, exist_ok=True)
    # The threshold depends only on the fixed pseudo labels and p, so compute
    # it once up front instead of once per batch (was recomputed every iter).
    score_thresh = consistency.get_score_thresh(pseudo_labels, p, separate=True)
    for iter_idx, (images_batch, proj_mats_batch, joints_3d_gt_batch,
                   joints_3d_valid_batch, joints_2d_gt_batch, indexes) \
            in enumerate(dataloader):
        if images_batch is None:
            continue
        batch_size = images_batch.shape[0]
        indexes_abs = [idx + start_index for idx in indexes]
        joints_2d_pl_batch, joints_2d_valid_batch = \
            consistency.get_pseudo_labels(pseudo_labels, indexes_abs,
                                          images_batch.shape[1], score_thresh)
        for batch_i in range(batch_size):
            vis = visualize.visualize_pseudo_labels(
                images_batch[batch_i], joints_2d_pl_batch[batch_i],
                joints_2d_valid_batch[batch_i])
            im = Image.fromarray(vis)
            print(save_folder)
            print("%06d.png" % (iter_idx * batch_size + batch_i))
            # Sequential numbering assumes a constant batch size across
            # iterations (last batch may be smaller — names stay unique,
            # just not gap-free).
            img_path = os.path.join(save_folder,
                                    "%06d.png" % (iter_idx * batch_size + batch_i))
            im.save(img_path)
def train_one_epoch_ssl(config, model, syn_train_loader, real_train_loader, \
                        criterion, metric, opt, e, device, \
                        checkpoint_dir, writer=None, \
                        gamma=10, \
                        log_every_iters=1, vis_every_iters=1):
    """Train `model` for one epoch on synthetic data (GT labels) plus real
    data (2D pseudo labels) — semi-supervised joint training.

    The combined objective is ``syn_loss + gamma * real_loss``; the real
    branch requires a HeatmapMSELoss criterion.

    Args:
        config: experiment config; `config.dataset.type` selects the
            pseudo-label file loaded from ``pseudo_labels/<type>_train.npy``.
        model: multi-view pose model returning
            (joints_3d, joints_2d, heatmaps, confidences, features).
        syn_train_loader / real_train_loader: loaders iterated in lockstep
            with `zip` (epoch length = shorter of the two).
        criterion: loss module; HeatmapMSELoss enables 2D-heatmap training.
        metric: detection metric passed to `utils_eval.eval_one_batch`.
        opt: optimizer.
        e: epoch index (used for tensorboard global steps).
        device: torch device.
        checkpoint_dir: unused here; kept for interface compatibility.
        writer: optional tensorboard SummaryWriter.
        gamma: weight of the real-data loss term.
        log_every_iters / vis_every_iters: logging and visualization cadence.

    Returns:
        (mean_loss_syn, pck_acc_syn, mean_error_syn,
         mean_loss_real, pck_acc_real, mean_error_real)

    NOTE(review): if the loop body never completes one full iteration the
    return statement raises UnboundLocalError / ZeroDivisionError — same as
    the original behavior; confirm loaders are never empty.
    """
    model.train()
    batch_size = syn_train_loader.batch_size
    iters_per_epoch = round(min(len(syn_train_loader.dataset) / syn_train_loader.batch_size, \
                                len(real_train_loader.dataset) / real_train_loader.batch_size))
    print("Estimated iterations per epoch is %d." % iters_per_epoch)

    # Running totals for epoch-level metrics.
    total_train_loss_syn = 0
    total_detected_syn = 0
    total_error_syn = 0
    total_samples_syn = 0  # num_joints or num_frames
    total_train_loss_real = 0
    total_detected_real = 0
    total_error_real = 0
    total_samples_real = 0  # num_joints or num_frames
    total_detected_per_joint_real = torch.zeros((17, )).to(device)  # hardcoded
    total_num_per_joint_real = torch.zeros((17, )).to(device)  # hardcoded

    # Fix: the pseudo labels and the derived score threshold are constant for
    # the whole epoch; previously they were re-loaded from disk and recomputed
    # on EVERY iteration, which is pure overhead. (Assumes the .npy file does
    # not change mid-epoch — it is written by a separate labeling pass.)
    pseudo_labels = np.load("pseudo_labels/%s_train.npy" % config.dataset.type,
                            allow_pickle=True).item()  # load pseudo labels
    # p = 0.2 * (e // 10 + 1)  # percentage (previous schedule, kept for reference)
    p = 0.2  # hardcoded
    score_thresh = consistency.get_score_thresh(pseudo_labels, p, separate=True)

    # jointly training
    joint_loader = zip(syn_train_loader, real_train_loader)
    for iter_idx, ((syn_images_batch, syn_proj_mats_batch, syn_joints_3d_gt_batch,
                    syn_joints_3d_valid_batch, syn_joints_2d_gt_batch, syn_info_batch), \
                   (real_images_batch, real_proj_mats_batch, real_joints_3d_gt_batch,
                    real_joints_3d_valid_batch, real_joints_2d_gt_batch, real_indexes)) \
            in enumerate(joint_loader):
        opt.zero_grad()

        # ---- synthetic branch -------------------------------------------
        if syn_images_batch is None:
            continue
        syn_images_batch = syn_images_batch.to(device)
        syn_proj_mats_batch = syn_proj_mats_batch.to(device)
        syn_joints_3d_gt_batch = syn_joints_3d_gt_batch.to(device)
        syn_joints_3d_valid_batch = syn_joints_3d_valid_batch.to(device)
        syn_joints_2d_gt_batch = syn_joints_2d_gt_batch.to(device)
        syn_joints_3d_pred, syn_joints_2d_pred, syn_heatmaps_pred, syn_confidences_pred, _ = model(
            syn_images_batch, syn_proj_mats_batch)
        if isinstance(criterion, HeatmapMSELoss):
            syn_loss = criterion(syn_heatmaps_pred, syn_joints_2d_gt_batch)
        else:
            syn_loss = criterion(syn_joints_3d_pred, syn_joints_3d_gt_batch,
                                 syn_joints_3d_valid_batch)

        # ---- real branch (pseudo labels) --------------------------------
        if real_images_batch is None:
            continue  # skips the syn step too (no opt.step this iter) — original behavior
        real_images_batch = real_images_batch.to(device)
        if real_proj_mats_batch is not None:
            real_proj_mats_batch = real_proj_mats_batch.to(device)
        real_joints_3d_gt_batch = real_joints_3d_gt_batch.to(device)
        real_joints_3d_valid_batch = real_joints_3d_valid_batch.to(device)
        real_joints_2d_gt_batch = real_joints_2d_gt_batch.to(device)
        real_joints_3d_pred, real_joints_2d_pred, real_heatmaps_pred, real_confidences_pred, _ = model(
            real_images_batch, real_proj_mats_batch)
        real_joints_2d_pl_batch, real_joints_2d_valid_batch = \
            consistency.get_pseudo_labels(pseudo_labels, real_indexes,
                                          real_images_batch.shape[1], score_thresh)
        real_joints_2d_pl_batch = real_joints_2d_pl_batch.to(device)
        real_joints_2d_valid_batch = real_joints_2d_valid_batch.to(device)
        if isinstance(criterion, HeatmapMSELoss):
            # Supervise heatmaps with pseudo labels (GT variant kept disabled
            # upstream; 2D GT is only used for evaluation below).
            real_loss = criterion(real_heatmaps_pred, real_joints_2d_pl_batch,
                                  real_joints_2d_valid_batch)
        else:
            raise ValueError(
                "Please use 2D Heatmap Loss for training on real dataset!")

        # ---- optimize ----------------------------------------------------
        loss = syn_loss + gamma * real_loss
        loss.backward()
        opt.step()

        # ---- evaluate on syndata ----------------------------------------
        syn_detected, syn_error, syn_num_samples, _, _ \
            = utils_eval.eval_one_batch(metric, syn_joints_3d_pred, syn_joints_2d_pred, \
                                        syn_proj_mats_batch, syn_joints_3d_gt_batch,
                                        syn_joints_3d_valid_batch, \
                                        syn_joints_2d_gt_batch)
        total_train_loss_syn += syn_num_samples * syn_loss.item()
        total_detected_syn += syn_detected
        total_error_syn += syn_num_samples * syn_error
        total_samples_syn += syn_num_samples

        # ---- evaluate on real (h36m) ------------------------------------
        real_detected, real_error, real_num_samples, real_detected_per_joint, real_num_per_joint \
            = utils_eval.eval_one_batch(metric, real_joints_3d_pred, real_joints_2d_pred, \
                                        real_proj_mats_batch, real_joints_3d_gt_batch,
                                        real_joints_3d_valid_batch, \
                                        real_joints_2d_gt_batch)
        total_train_loss_real += real_num_samples * real_loss.item()
        total_detected_real += real_detected
        total_error_real += real_num_samples * real_error
        total_samples_real += real_num_samples
        total_detected_per_joint_real += real_detected_per_joint  # size: num_joints
        total_num_per_joint_real += real_num_per_joint  # size: num_joints

        # ---- iteration-level logging ------------------------------------
        if iter_idx % log_every_iters == log_every_iters - 1:
            # NOTE(review): this labels the log with the START of the logging
            # window, not the current iter — looks intentional, confirm.
            logging_iter = iter_idx + 1 - log_every_iters
            mean_loss_logging_syn = total_train_loss_syn / total_samples_syn
            pck_acc_logging_syn = total_detected_syn / total_samples_syn
            mean_error_logging_syn = total_error_syn / total_samples_syn
            mean_loss_logging_real = total_train_loss_real / total_samples_real
            pck_acc_logging_real = total_detected_real / total_samples_real
            mean_error_logging_real = total_error_real / total_samples_real
            pck_per_joint_logging_real = total_detected_per_joint_real / total_num_per_joint_real  # size: num_joints
            print("epoch: %d, iter: %d" % (e, logging_iter))
            print(" (Syndata) train loss: %f, train acc: %.3f, train error: %.3f" \
                  % (mean_loss_logging_syn, pck_acc_logging_syn, mean_error_logging_syn))
            print(" (Real) train loss: %f, train acc: %.3f, train error: %.3f" \
                  % (mean_loss_logging_real, pck_acc_logging_real, mean_error_logging_real))
            if writer is not None:
                writer.add_scalar("train_loss/syndata/iter", mean_loss_logging_syn,
                                  e * iters_per_epoch + logging_iter)
                writer.add_scalar("train_pck/syndata/iter", pck_acc_logging_syn,
                                  e * iters_per_epoch + logging_iter)
                writer.add_scalar("train_error/syndata/iter", mean_error_logging_syn,
                                  e * iters_per_epoch + logging_iter)
                writer.add_scalar("train_loss/real/iter", mean_loss_logging_real,
                                  e * iters_per_epoch + logging_iter)
                writer.add_scalar("train_pck/real/iter", pck_acc_logging_real,
                                  e * iters_per_epoch + logging_iter)
                writer.add_scalar("train_error/real/iter", mean_error_logging_real,
                                  e * iters_per_epoch + logging_iter)
                for jnt_i in range(len(pck_per_joint_logging_real)):
                    writer.add_scalar("train_pck/real/iter/joint_%d" % jnt_i,
                                      pck_per_joint_logging_real[jnt_i],
                                      e * iters_per_epoch + logging_iter)

        # ---- iteration-level visualization ------------------------------
        if iter_idx % vis_every_iters == 0:
            vis_iter = iter_idx
            # visualize first sample in batch
            if writer is not None:
                joints_vis_syn = visualize.visualize_pred_2D(
                    syn_images_batch[0], syn_joints_2d_gt_batch[0],
                    syn_joints_2d_pred[0])
                writer.add_image("joints/syndata/iter",
                                 joints_vis_syn.transpose(2, 0, 1),
                                 global_step=e * iters_per_epoch + vis_iter)
                joints_vis_real = visualize.visualize_pred_2D(
                    real_images_batch[0], real_joints_2d_gt_batch[0],
                    real_joints_2d_pred[0])
                writer.add_image("joints/real/iter",
                                 joints_vis_real.transpose(2, 0, 1),
                                 global_step=e * iters_per_epoch + vis_iter)
                # Cycle through joints so every joint's heatmap gets shown.
                vis_joint = (iter_idx // vis_every_iters) % 16
                heatmap_vis_syn = visualize.visualize_heatmap(syn_images_batch[0],
                                                              syn_joints_2d_gt_batch[0], \
                                                              syn_heatmaps_pred[0],
                                                              vis_joint=vis_joint)
                writer.add_image("heatmap/syndata/joint_%d/iter" % vis_joint,
                                 heatmap_vis_syn.transpose(2, 0, 1),
                                 global_step=e * iters_per_epoch + vis_iter)
                heatmap_vis_real = visualize.visualize_heatmap(real_images_batch[0],
                                                               real_joints_2d_gt_batch[0], \
                                                               real_heatmaps_pred[0],
                                                               vis_joint=vis_joint)
                writer.add_image("heatmap/real/joint_%d/iter" % vis_joint,
                                 heatmap_vis_real.transpose(2, 0, 1),
                                 global_step=e * iters_per_epoch + vis_iter)

    # ---- epoch-level logging --------------------------------------------
    mean_loss_syn = total_train_loss_syn / total_samples_syn
    pck_acc_syn = total_detected_syn / total_samples_syn
    mean_error_syn = total_error_syn / total_samples_syn
    mean_loss_real = total_train_loss_real / total_samples_real
    pck_acc_real = total_detected_real / total_samples_real
    mean_error_real = total_error_real / total_samples_real
    pck_per_joint_real = total_detected_per_joint_real / total_num_per_joint_real  # size: num_joints
    if writer is not None:
        writer.add_scalar("train_loss/syndata/epoch", mean_loss_syn, e)
        writer.add_scalar("train_pck/syndata/epoch", pck_acc_syn, e)
        writer.add_scalar("train_error/syndata/epoch", mean_error_syn, e)
        writer.add_scalar("train_loss/real/epoch", mean_loss_real, e)
        writer.add_scalar("train_pck/real/epoch", pck_acc_real, e)
        writer.add_scalar("train_error/real/epoch", mean_error_real, e)
        for jnt_i in range(len(pck_per_joint_real)):
            writer.add_scalar("train_pck/real/epoch/joint_%d" % jnt_i,
                              pck_per_joint_real[jnt_i], e)
    return mean_loss_syn, pck_acc_syn, mean_error_syn, \
        mean_loss_real, pck_acc_real, mean_error_real
def generate_features(config, model, dataloader, device, label_path, write_path):
    """Extract per-joint backbone feature vectors at pseudo-label locations
    and save them (plus run metadata) as a structured .npy file at
    `write_path`. Skips the whole run if `write_path` already exists.
    """
    if os.path.exists(write_path):
        print("File %s already exists" % write_path)
        return
    num_joints = config.model.backbone.num_joints
    # fill in parameters
    retval = {
        'dataset': config.dataset.type,
        'split': "train",
        'image_shape': config.dataset.image_shape,
        'num_joints': config.model.backbone.num_joints,
        'scale_bbox': config.dataset.train.scale_bbox,
        'retain_every_n_frames': config.dataset.train.retain_every_n_frames,
        'pseudo_label_path': label_path
    }
    # One record per (sample, view): the 256-dim feature vector of every joint.
    # NOTE(review): 256 is the backbone channel width — hardcoded; confirm it
    # matches config.model.backbone.
    feats_dtype = np.dtype([('data_idx', np.int32), ('view_idx', np.int8),
                            ('feats', np.float32, (num_joints, 256))])
    retval['features'] = []
    # re-configure data loader
    # NOTE(review): setting .shuffle after a DataLoader is built likely has no
    # effect (the sampler is fixed at construction) — verify the loader is
    # created with shuffle=False upstream.
    dataloader.shuffle = False
    # load pseudo labels
    pseudo_labels = np.load(label_path, allow_pickle=True).item()
    print("Generating feature vectors...")
    print("Estimated number of iterations is: %d" %
          round(dataloader.dataset.__len__() / dataloader.batch_size))
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        for iter_idx, (images_batch, _, _, _, _, indexes) in enumerate(dataloader):
            if images_batch is None:
                continue
            images_batch = images_batch.to(device)
            batch_size = images_batch.shape[0]
            num_views = images_batch.shape[1]
            image_shape = images_batch.shape[3:]  # (H, W) of the input images
            assert batch_size == len(indexes)
            # Last model output is the intermediate feature map.
            _, _, _, _, features = model(images_batch, None)
            feature_shape = features.shape[3:]  # (H', W') of the feature map
            # Scale factors mapping image-space pixel coords to feature-map coords.
            ratio_h = feature_shape[0] / image_shape[0]
            ratio_w = feature_shape[1] / image_shape[1]
            # score_thresh=0 keeps every pseudo label.
            joints_2d_batch, _ = consistency.get_pseudo_labels(
                pseudo_labels, indexes, num_views, 0)
            joints_2d_batch = joints_2d_batch.cpu().numpy()
            # Truncate (not round) to integer feature-map indices.
            feat_xs = (joints_2d_batch[:, :, :, 0] * ratio_w).astype(np.int32)
            feat_ys = (joints_2d_batch[:, :, :, 1] * ratio_h).astype(np.int32)
            # fill in pseudo labels
            for batch_idx, data_idx in enumerate(indexes):
                for view_idx in range(num_views):
                    feat_x = feat_xs[batch_idx, view_idx, :]  # (num_joints,)
                    feat_y = feat_ys[batch_idx, view_idx, :]  # (num_joints,)
                    feats_segment = np.empty(1, dtype=feats_dtype)
                    feats_segment['data_idx'] = data_idx
                    feats_segment['view_idx'] = view_idx
                    # Advanced indexing gathers one channel vector per joint;
                    # the transpose arranges it as (num_joints, 256) to match
                    # feats_dtype — TODO(review): confirm the gathered shape,
                    # mixed scalar/array indexing puts broadcast dims first.
                    feats_segment['feats'] = np.transpose(
                        features[batch_idx, view_idx, :, feat_y, feat_x].cpu().numpy())
                    retval['features'].append(feats_segment)
    retval['features'] = np.concatenate(retval['features'])
    assert retval['features'].ndim == 1
    print("Total number of images in the dataset: ", len(retval['features']))
    # save pseudo labels
    # NOTE(review): "feats" dir is created but the file is saved to write_path,
    # which may live elsewhere — presumably callers pass write_path under
    # "feats/"; confirm.
    save_folder = "feats"
    os.makedirs(save_folder, exist_ok=True)
    print("Saving features to %s" % write_path)
    np.save(write_path, retval)
    print("Done.")
def plot_tSNE(dataloader, device, label_path, feat_path):
    """t-SNE-visualize pseudo-label feature vectors per joint, colored by
    whether the pseudo label is within `error_thresh` pixels of the 2D GT
    (inlier) or not (outlier). One figure per joint is written to
    ``figs/tsne_pseudo_labels_joint_<j>.png``.

    Args:
        dataloader: yields (images, _, _, _, joints_2d_gt, indexes).
        device: unused here; kept for interface compatibility.
        label_path: .npy pseudo-label dict (must contain 'num_joints').
        feat_path: .npy feature file produced by `generate_features`.
    """
    # re-configure data loader
    # NOTE(review): setting .shuffle after construction likely has no effect
    # on a torch DataLoader — verify the loader is built with shuffle=False.
    dataloader.shuffle = False
    # load pseudo labels and their features
    pseudo_labels = np.load(label_path, allow_pickle=True).item()
    feats_npy = np.load(feat_path, allow_pickle=True).item()
    num_joints = pseudo_labels['num_joints']
    # Per-joint accumulators for filtered features / inlier flags.
    features_pl = []
    detections_pl = []
    for jnt_idx in range(num_joints):
        features_pl.append(np.empty((0, 256)))  # 256-dim backbone features (hardcoded)
        # Fix: np.bool was removed in NumPy 1.24 — use the builtin bool.
        detections_pl.append(np.empty(0, ).astype(bool))
    error_thresh = 20  # pixels: pseudo label counts as inlier below this
    p = 0.2
    score_thresh = consistency.get_score_thresh(pseudo_labels, p, separate=True)
    for iter_idx, (images_batch, _, _, _, joints_2d_gt_batch,
                   indexes) in enumerate(dataloader):
        if images_batch is None:
            continue
        joints_2d_pseudo, joints_2d_valid_batch = \
            consistency.get_pseudo_labels(pseudo_labels, indexes,
                                          images_batch.shape[1], score_thresh)
        features_batch = get_features(feats_npy, indexes, images_batch.shape[1],
                                      joints_2d_valid_batch)
        # Inlier = pseudo label within error_thresh pixels of the 2D GT.
        detections = (torch.norm(joints_2d_pseudo - joints_2d_gt_batch,
                                 dim=-1, keepdim=True) < error_thresh)
        for jnt_idx in range(num_joints):
            # Keep only entries where the pseudo label is valid.
            detections_jnt = detections[:, :, jnt_idx, :][
                joints_2d_valid_batch[:, :, jnt_idx, :]]
            features_pl[jnt_idx] = np.vstack(
                (features_pl[jnt_idx], features_batch[jnt_idx]))
            assert len(features_batch[jnt_idx]) == len(detections_jnt)
            detections_pl[jnt_idx] = np.concatenate(
                (detections_pl[jnt_idx], detections_jnt))
    # t-SNE visualization for each joint pseudo label
    for jnt_idx in range(num_joints):
        print("Joint %d:" % jnt_idx)
        time_start = time.time()
        tsne = TSNE(n_components=2)
        tsne_results = tsne.fit_transform(features_pl[jnt_idx])
        print("t-SNE done. Time elapsed: %f secs" % (time.time() - time_start))
        # plot
        pos = tsne_results[detections_pl[jnt_idx], :]
        neg = tsne_results[~detections_pl[jnt_idx], :]
        plt.title("t-SNE visualization")
        plt.scatter(x=pos[:, 0], y=pos[:, 1], alpha=0.3, label="inlier")
        plt.scatter(x=neg[:, 0], y=neg[:, 1], alpha=0.3, label="outlier")
        plt.legend()
        plt.savefig("figs/tsne_pseudo_labels_joint_%d.png" % jnt_idx)
        plt.close()
def eval_pseudo_labels(dataset='human36m', p=0.2, separate=True,
                       triangulate=False):  # hardcoded
    """Evaluate 2D pseudo labels against ground truth (PCK / PCKh / mean 2D
    error) and save an error histogram to
    ``figs/errors_pseudo_labels_<dataset>.png``.

    Args:
        dataset: 'human36m' or 'mpii' — selects dataset, loader and
            pseudo-label file (paths are hardcoded below).
        p: percentage used to derive the confidence-score threshold.
        separate: per-joint thresholds if True (passed to get_score_thresh).
        triangulate: if True, refine pseudo labels by triangulation
            (multi-view only; invalid for MPII).

    Raises:
        ValueError: for MPII with triangulate=True, or an unknown `dataset`.
    """
    if dataset == 'mpii' and triangulate:
        raise ValueError(
            "MPII dataset is not multiview, please use multiview dataset.")
    print("Loading data ...")
    if dataset == 'human36m':
        # Fix: the original wrote `train_set = dataset = Human36M...(...)`,
        # which rebound the `dataset` name to the dataset OBJECT, so the
        # savefig filename below embedded an object repr instead of the name.
        train_set = Human36MMultiViewDataset(
            h36m_root=
            "../learnable-triangulation-pytorch/data/human36m/processed/",
            train=True,
            image_shape=[384, 384],
            labels_path=
            "../learnable-triangulation-pytorch/data/human36m/extra/human36m-multiview-labels-GTbboxes.npy",
            with_damaged_actions=True,
            retain_every_n_frames=10,
            scale_bbox=1.6,
            kind="human36m",
            undistort_images=True,
            ignore_cameras=[],
            crop=True,
        )
        train_loader = datasets_utils.human36m_loader(train_set, \
                                                      batch_size=4, \
                                                      shuffle=False, \
                                                      num_workers=4)
        pseudo_labels = np.load(
            "pseudo_labels/human36m_train_every_10_frames.npy",
            allow_pickle=True).item()
        thresh = 1  # PCKh threshold for Human3.6M
    elif dataset == 'mpii':
        train_set = Mpii(
            image_path="../mpii_images",
            anno_path="../pytorch-pose/data/mpii/mpii_annotations.json",
            inp_res=384,
            out_res=96,
            is_train=True)
        train_loader = datasets_utils.mpii_loader(train_set, \
                                                  batch_size=256, \
                                                  shuffle=False, \
                                                  num_workers=4)
        pseudo_labels = np.load("pseudo_labels/mpii_train.npy",
                                allow_pickle=True).item()
        thresh = 0.5  # PCKh threshold for MPII
    else:
        # Fix: previously an unknown dataset fell through and crashed later
        # with NameError on train_loader/pseudo_labels/thresh.
        raise ValueError("Unknown dataset: %s" % dataset)
    print("Data loaded.")
    score_thresh = consistency.get_score_thresh(pseudo_labels, p,
                                                separate=separate)
    total_joints = 0
    total_detected_pck = 0
    total_detected_pckh = 0
    errors = torch.empty(0)
    for iter_idx, (images_batch, proj_mats_batch, joints_3d_gt_batch,
                   joints_3d_valid_batch, joints_2d_gt_batch,
                   indexes) in enumerate(train_loader):
        print(iter_idx)
        if images_batch is None:
            continue
        joints_2d_pseudo, joints_2d_valid_batch = \
            consistency.get_pseudo_labels(pseudo_labels, indexes,
                                          images_batch.shape[1], score_thresh)
        if triangulate:
            num_valid_before = joints_2d_valid_batch.sum()
            joints_2d_pseudo, joints_2d_valid_batch = triangulate_pseudo_labels(
                proj_mats_batch, joints_2d_pseudo, joints_2d_valid_batch)
            num_valid_after = joints_2d_valid_batch.sum()
            print("Number of valid labels: before: %d, after: %d" %
                  (num_valid_before, num_valid_after))
        detected_pck, num_jnts, _, _ = PCK()(joints_2d_pseudo,
                                             joints_2d_gt_batch,
                                             joints_2d_valid_batch)
        detected_pckh, _, _, _ = PCKh(thresh=thresh)(joints_2d_pseudo,
                                                     joints_2d_gt_batch,
                                                     joints_2d_valid_batch)
        # Per-joint 2D Euclidean error; keepdim (canonical torch spelling;
        # `keepdims` is only an alias in newer releases).
        diff = torch.sqrt(
            torch.sum((joints_2d_pseudo - joints_2d_gt_batch)**2,
                      dim=-1,
                      keepdim=True))
        error_2d = diff[joints_2d_valid_batch]
        errors = torch.cat((errors, error_2d))
        total_joints += num_jnts
        total_detected_pck += detected_pck
        total_detected_pckh += detected_pckh
    errors = errors.cpu().numpy()
    print("PCK:", total_detected_pck / total_joints)
    print("PCKh:", total_detected_pckh / total_joints)
    print("Error(2D):", errors.mean())
    # plot histogram
    plt.hist(errors, bins=100, density=True)
    plt.title("2D error distribution in pseudo labels")
    plt.xlabel("2D error (in pixel)")
    plt.ylabel("density")
    plt.savefig("figs/errors_pseudo_labels_%s.png" % dataset)
    plt.close()