Exemplo n.º 1
0
def one_epoch(model, criterion, opt, config, dataloader, device, epoch, n_iters_total=0, is_train=True, caption='', master=False, experiment_dir=None, writer=None):
    """Run a single training or validation epoch.

    For each batch the inputs are prepared with ``dataset_utils.prepare_batch``,
    the model is run, ``criterion`` (plus the optional volumetric CE loss) is
    computed and, when ``is_train`` is true, the optimizer takes a step. The
    master process additionally writes images, parameter histograms and scalar
    stats to ``writer``; at the end of a validation epoch it evaluates the
    collected predictions via ``dataloader.dataset.evaluate`` and dumps
    ``results.pkl`` and ``metric.json`` to ``<experiment_dir>/checkpoints/<epoch>``.

    Args:
        model: network; for ``config.model.name`` "alg"/"ransac" it must return
            a 4-tuple, for "vol" a 7-tuple (see the unpacking below).
        criterion: keypoint loss, called as ``criterion(pred, gt, validity)``.
        opt: optimizer, only used when ``is_train`` is true.
        config: experiment configuration.
        dataloader: iterable of batches; ``None`` batches are skipped.
        device: forwarded to ``dataset_utils.prepare_batch``.
        epoch: epoch index, used for logging and the checkpoint directory name.
        n_iters_total: global iteration counter at the start of this epoch.
        is_train: if true, enable gradients and update the model.
        caption: unused; kept for interface compatibility.
        master: whether this process logs to ``writer`` and writes result dumps.
        experiment_dir: root directory for result dumps (validation, master only).
        writer: tensorboard summary writer (master only).

    Returns:
        The updated global iteration counter (advanced only on the master process).

    Raises:
        NotImplementedError: if ``config.model.name`` is not "alg", "ransac" or "vol".
        ValueError: if a single-view batch or the base-point metric meets a
            ``config.kind`` / ``config.model.kind`` with no known base joint.
    """
    name = "train" if is_train else "val"
    model_type = config.model.name

    if is_train:
        model.train()
    else:
        model.eval()

    metric_dict = defaultdict(list)

    results = defaultdict(list)

    # Loop-invariant config lookups (missing keys fall back to the defaults).
    n_iters_per_epoch = getattr(config.opt, "n_iters_per_epoch", None)
    scale_keypoints_3d = getattr(config.opt, "scale_keypoints_3d", 1.0)
    use_volumetric_ce_loss = getattr(config.opt, "use_volumetric_ce_loss", False)
    volumetric_ce_loss_weight = getattr(config.opt, "volumetric_ce_loss_weight", 1.0)
    volumetric_ce_criterion = VolumetricCELoss() if use_volumetric_ce_loss else None

    # used to turn on/off gradients
    grad_context = torch.autograd.enable_grad if is_train else torch.no_grad
    with grad_context():
        end = time.time()

        iterator = enumerate(dataloader)
        if is_train and n_iters_per_epoch is not None:
            iterator = islice(iterator, n_iters_per_epoch)

        for iter_i, batch in iterator:
            with autograd.detect_anomaly():
                # measure data loading time
                data_time = time.time() - end

                if batch is None:
                    print("Found None batch")
                    continue

                images_batch, keypoints_3d_gt, keypoints_3d_validity_gt, proj_matricies_batch = dataset_utils.prepare_batch(batch, device, config)

                keypoints_2d_pred, cuboids_pred, base_points_pred = None, None, None
                if model_type == "alg" or model_type == "ransac":
                    keypoints_3d_pred, keypoints_2d_pred, heatmaps_pred, confidences_pred = model(images_batch, proj_matricies_batch, batch)
                elif model_type == "vol":
                    keypoints_3d_pred, heatmaps_pred, volumes_pred, confidences_pred, cuboids_pred, coord_volumes_pred, base_points_pred = model(images_batch, proj_matricies_batch, batch)
                else:
                    # previously fell through and crashed later with an UnboundLocalError
                    raise NotImplementedError(f"Unknown model type {model_type}")

                batch_size, n_views = images_batch.shape[0], images_batch.shape[1]
                n_joints = keypoints_3d_pred.shape[1]

                keypoints_3d_binary_validity_gt = (keypoints_3d_validity_gt > 0.0).type(torch.float32)

                # 1-view case: express every joint relative to a base joint, for both
                # gt and prediction (presumably because absolute position is
                # ambiguous from a single view -- TODO confirm)
                if n_views == 1:
                    if config.kind == "human36m":
                        base_joint = 6
                    elif config.kind == "coco":
                        base_joint = 11
                    else:
                        raise ValueError(f"No base joint defined for the 1-view case of kind {config.kind!r}")

                    not_base_joint = torch.arange(n_joints) != base_joint

                    keypoints_3d_gt_transformed = keypoints_3d_gt.clone()
                    keypoints_3d_gt_transformed[:, not_base_joint] -= keypoints_3d_gt_transformed[:, base_joint:base_joint + 1]
                    keypoints_3d_gt = keypoints_3d_gt_transformed

                    keypoints_3d_pred_transformed = keypoints_3d_pred.clone()
                    keypoints_3d_pred_transformed[:, not_base_joint] -= keypoints_3d_pred_transformed[:, base_joint:base_joint + 1]
                    keypoints_3d_pred = keypoints_3d_pred_transformed

                # calculate loss
                total_loss = 0.0
                loss = criterion(keypoints_3d_pred * scale_keypoints_3d, keypoints_3d_gt * scale_keypoints_3d, keypoints_3d_binary_validity_gt)
                total_loss += loss
                metric_dict[f'{config.opt.criterion}'].append(loss.item())

                # volumetric ce loss
                volumetric_ce_loss_value = None
                if use_volumetric_ce_loss:
                    loss = volumetric_ce_criterion(coord_volumes_pred, volumes_pred, keypoints_3d_gt, keypoints_3d_binary_validity_gt)
                    volumetric_ce_loss_value = loss.item()
                    metric_dict['volumetric_ce_loss'].append(volumetric_ce_loss_value)

                    total_loss += volumetric_ce_loss_weight * loss

                metric_dict['total_loss'].append(total_loss.item())
                # Report the CE term from a local: indexing metric_dict['volumetric_ce_loss']
                # when the loss is disabled raised IndexError and inserted an empty
                # list that later turned the per-epoch mean into NaN.
                print('epoch: {}, {}: {}, volumetric_ce_loss: {}, total_loss: {}'.format(epoch, config.opt.criterion, metric_dict[f'{config.opt.criterion}'][-1], volumetric_ce_loss_value, total_loss.item()))

                if is_train:
                    opt.zero_grad()
                    total_loss.backward()

                    if hasattr(config.opt, "grad_clip"):
                        torch.nn.utils.clip_grad_norm_(model.parameters(), config.opt.grad_clip / config.opt.lr)

                    metric_dict['grad_norm_times_lr'].append(config.opt.lr * misc.calc_gradient_norm(filter(lambda x: x[1].requires_grad, model.named_parameters())))

                    opt.step()

                # calculate metrics
                l2 = KeypointsL2Loss()(keypoints_3d_pred * scale_keypoints_3d, keypoints_3d_gt * scale_keypoints_3d, keypoints_3d_binary_validity_gt)
                metric_dict['l2'].append(l2.item())

                # base point l2
                if base_points_pred is not None:
                    base_point_l2_list = []
                    for batch_i in range(batch_size):
                        base_point_pred = base_points_pred[batch_i]

                        if config.model.kind == "coco":
                            # midpoint of gt joints 11 and 12
                            base_point_gt = (keypoints_3d_gt[batch_i, 11, :3] + keypoints_3d_gt[batch_i, 12, :3]) / 2
                        elif config.model.kind == "mpii":
                            base_point_gt = keypoints_3d_gt[batch_i, 6, :3]
                        else:
                            raise ValueError(f"No base point defined for model kind {config.model.kind!r}")

                        base_point_l2_list.append(torch.sqrt(torch.sum((base_point_pred * scale_keypoints_3d - base_point_gt * scale_keypoints_3d) ** 2)).item())

                    base_point_l2 = 0.0 if len(base_point_l2_list) == 0 else np.mean(base_point_l2_list)
                    metric_dict['base_point_l2'].append(base_point_l2)

                # save answers for evaluation
                if not is_train:
                    results['keypoints_3d'].append(keypoints_3d_pred.detach().cpu().numpy())
                    results['indexes'].append(batch['indexes'])

                # plot visualization
                if master:
                    if n_iters_total % config.vis_freq == 0:
                        vis_kind = config.kind
                        if (config.transfer_cmu_to_human36m if hasattr(config, "transfer_cmu_to_human36m") else False):
                            vis_kind = "coco"

                        for batch_i in range(min(batch_size, config.vis_n_elements)):
                            keypoints_vis = vis.visualize_batch(
                                images_batch, heatmaps_pred, keypoints_2d_pred, proj_matricies_batch,
                                keypoints_3d_gt, keypoints_3d_pred,
                                kind=vis_kind,
                                cuboids_batch=cuboids_pred,
                                confidences_batch=confidences_pred,
                                batch_index=batch_i, size=5,
                                max_n_cols=10
                            )
                            writer.add_image(f"{name}/keypoints_vis/{batch_i}", keypoints_vis.transpose(2, 0, 1), global_step=n_iters_total)

                            heatmaps_vis = vis.visualize_heatmaps(
                                images_batch, heatmaps_pred,
                                kind=vis_kind,
                                batch_index=batch_i, size=5,
                                max_n_rows=10, max_n_cols=10
                            )
                            writer.add_image(f"{name}/heatmaps/{batch_i}", heatmaps_vis.transpose(2, 0, 1), global_step=n_iters_total)

                            if model_type == "vol":
                                volumes_vis = vis.visualize_volumes(
                                    images_batch, volumes_pred, proj_matricies_batch,
                                    kind=vis_kind,
                                    cuboids_batch=cuboids_pred,
                                    batch_index=batch_i, size=5,
                                    max_n_rows=1, max_n_cols=16
                                )
                                writer.add_image(f"{name}/volumes/{batch_i}", volumes_vis.transpose(2, 0, 1), global_step=n_iters_total)

                    # dump weights to tensorboard
                    if n_iters_total % config.vis_freq == 0:
                        for p_name, p in model.named_parameters():
                            try:
                                writer.add_histogram(p_name, p.clone().cpu().data.numpy(), n_iters_total)
                            except ValueError as e:
                                print(e)
                                print(p_name, p)
                                exit()

                    # dump to tensorboard per-iter loss/metric stats
                    if is_train:
                        for title, value in metric_dict.items():
                            writer.add_scalar(f"{name}/{title}", value[-1], n_iters_total)

                    # measure elapsed time
                    batch_time = time.time() - end
                    end = time.time()

                    # dump to tensorboard per-iter time stats
                    writer.add_scalar(f"{name}/batch_time", batch_time, n_iters_total)
                    writer.add_scalar(f"{name}/data_time", data_time, n_iters_total)

                    # dump to tensorboard per-iter stats about sizes
                    writer.add_scalar(f"{name}/batch_size", batch_size, n_iters_total)
                    writer.add_scalar(f"{name}/n_views", n_views, n_iters_total)

                    n_iters_total += 1

    # calculate evaluation metrics
    if master:
        if not is_train:
            results['keypoints_3d'] = np.concatenate(results['keypoints_3d'], axis=0)
            results['indexes'] = np.concatenate(results['indexes'])

            try:
                scalar_metric, full_metric = dataloader.dataset.evaluate(results['keypoints_3d'])
            except Exception as e:
                print("Failed to evaluate. Reason: ", e)
                scalar_metric, full_metric = 0.0, {}

            metric_dict['dataset_metric'].append(scalar_metric)  # mean per pose relative error in human36m
            print('epoch: {}, dataset_metric: {}'.format(epoch, scalar_metric))

            checkpoint_dir = os.path.join(experiment_dir, "checkpoints", "{:04}".format(epoch))
            os.makedirs(checkpoint_dir, exist_ok=True)

            # dump results
            with open(os.path.join(checkpoint_dir, "results.pkl"), 'wb') as fout:
                pickle.dump(results, fout)

            # dump full metric
            with open(os.path.join(checkpoint_dir, "metric.json"), 'w') as fout:
                json.dump(full_metric, fout, indent=4, sort_keys=True)

        # dump to tensorboard per-epoch stats
        for title, value in metric_dict.items():
            writer.add_scalar(f"{name}/{title}_epoch", np.mean(value), epoch)

    return n_iters_total
def one_epoch(model, criterion, opt, config, dataloader, device, epoch, n_iters_total=0, is_train=True, caption='', master=False, experiment_dir=None, writer=None):
    """Run a single training or evaluation epoch.

    For each batch the inputs are prepared with ``dataset_utils.prepare_batch``,
    the model is run, ``criterion`` (plus the optional volumetric CE loss) is
    computed and, when ``is_train`` is true, the optimizer takes a step. The
    master process writes images, parameter histograms and scalar stats to
    ``writer``; at the end of an evaluation epoch it evaluates the collected
    predictions via ``dataloader.dataset.evaluate`` and dumps ``results.pkl``,
    ``metric.json`` and (if ``config.save_extra_data``) ``extra_data.pkl`` to
    ``<experiment_dir>/checkpoints/<epoch>``. Verbose progress output is gated
    by the module-level ``DEBUG`` flag.

    Config keys read with defaults: ``config.save_extra_data`` (False),
    ``config.model.transfer_cmu_to_human36m`` (False), ``config.opt.n_iters_per_epoch``
    / ``config.opt.n_iters_per_epoch_val`` (None = no cap),
    ``config.opt.scale_keypoints_3d`` (1.0), ``config.opt.scale_keypoints_3d_gt``
    (same as the prediction scale), ``config.opt.use_volumetric_ce_loss`` (False),
    ``config.opt.volumetric_ce_loss_weight`` (1.0).

    Args:
        model: network; for ``config.model.name`` "alg"/"ransac" it must return
            a 4-tuple, for "vol" a 7-tuple (see the unpacking below).
        criterion: keypoint loss, called as ``criterion(pred, gt, validity)``.
        opt: optimizer, only used when ``is_train`` is true.
        config: experiment configuration.
        dataloader: iterable of batches; ``None`` batches are skipped.
        device: forwarded to ``dataset_utils.prepare_batch``.
        epoch: epoch index, used for logging and the checkpoint directory name.
        n_iters_total: global iteration counter at the start of this epoch.
        is_train: if true, enable gradients and update the model.
        caption: unused; kept for interface compatibility.
        master: whether this process logs to ``writer`` and writes result dumps.
        experiment_dir: root directory for result dumps (evaluation, master only).
        writer: tensorboard summary writer (master only).

    Returns:
        The updated global iteration counter (advanced only on the master process).

    Raises:
        NotImplementedError: if ``config.model.name`` is not "alg", "ransac" or "vol".
        ValueError: if a single-view batch or the base-point metric meets a
            ``config.kind`` / ``config.model.kind`` with no known base joint.
    """
    name = "train" if is_train else "val"
    model_type = config.model.name

    if is_train:
        model.train()
    else:
        model.eval()

    metric_dict = defaultdict(list)

    results = defaultdict(list)

    save_extra_data = getattr(config, "save_extra_data", False)
    # only filled / dumped when save_extra_data is set
    extra_data = defaultdict(list)

    transfer_cmu_h36m = getattr(config.model, "transfer_cmu_to_human36m", False)

    print("Transfer CMU to H36M: ", transfer_cmu_h36m)
    print("Using GT Pelvis position: ", getattr(config.model, "use_gt_pelvis", False))
    print("Using cameras: ", getattr(dataloader.dataset, "choose_cameras", False))
    print("Debug Mode: ", DEBUG)
    print("Training: ", is_train)

    train_eval_mode = "Train" if is_train else "Eval"

    # Loop-invariant config lookups. Missing iteration caps mean "no cap"
    # (previously a missing n_iters_per_epoch_val raised AttributeError).
    n_iters_per_epoch = getattr(config.opt, "n_iters_per_epoch", None)
    n_iters_per_epoch_val = getattr(config.opt, "n_iters_per_epoch_val", None)

    # Ground truth and prediction may come at different scales (e.g. a different
    # pretrained model); scale_keypoints_3d_gt lets the config set the gt factor
    # separately and defaults to the prediction factor.
    scale_keypoints_3d = getattr(config.opt, "scale_keypoints_3d", 1.0)
    scale_keypoints_3d_gt = getattr(config.opt, "scale_keypoints_3d_gt", scale_keypoints_3d)

    use_volumetric_ce_loss = getattr(config.opt, "use_volumetric_ce_loss", False)
    volumetric_ce_loss_weight = getattr(config.opt, "volumetric_ce_loss_weight", 1.0)
    volumetric_ce_criterion = VolumetricCELoss() if use_volumetric_ce_loss else None

    # used to turn on/off gradients
    grad_context = torch.autograd.enable_grad if is_train else torch.no_grad
    with grad_context():
        end = time.time()

        iterator = enumerate(dataloader)

        if is_train and n_iters_per_epoch is not None:
            iterator = islice(iterator, n_iters_per_epoch)

        if not is_train and n_iters_per_epoch_val is not None:
            iterator = islice(iterator, n_iters_per_epoch_val)

        # Data breakdown of a batch element (as described by the original author):
        # - For each of the (max) 31 cameras in CMU dataset:
        #     - OpenCV image: numpy array [likely cropped to a smaller shape]
        #     - BBOX detection for the image: (left, top, right, bottom) tuple
        #     - Camera: `Camera` object from `multiview.py`
        # - Index: int
        # - Keypoints (gt): numpy array, (17, 4)
        # - Keypoints (pred): numpy array, (17, 4) [may not be present]

        # batch indices to skip in evaluation mode
        ignore_batch = set()

        for iter_i, batch in iterator:
            if not is_train and iter_i in ignore_batch:
                continue

            # measure data loading time
            data_time = time.time() - end

            if batch is None:
                print(f"[{train_eval_mode}, {epoch}] Found None batch: {iter_i}")
                continue

            if DEBUG:
                print(f"{train_eval_mode} batch {iter_i}...")
                print(f"[{train_eval_mode}, {epoch}, {iter_i}] Preparing batch... ", end="")

            images_batch, keypoints_3d_gt, keypoints_3d_validity_gt, proj_matricies_batch = dataset_utils.prepare_batch(batch, device, config)

            if DEBUG:
                print("Prepared!")
                print(f"[{train_eval_mode}, {epoch}, {iter_i}] Running {model_type} model... ", end="")

            keypoints_2d_pred, cuboids_pred, base_points_pred = None, None, None
            if model_type == "alg" or model_type == "ransac":
                keypoints_3d_pred, keypoints_2d_pred, heatmaps_pred, confidences_pred = model(images_batch, proj_matricies_batch, batch)
            elif model_type == "vol":
                keypoints_3d_pred, heatmaps_pred, volumes_pred, confidences_pred, cuboids_pred, coord_volumes_pred, base_points_pred = model(images_batch, proj_matricies_batch, batch)
            else:
                raise NotImplementedError(f"Unknown model type {model_type}")

            if DEBUG:
                print("Done!")

            # images_batch.shape[2] is likely the number of channels;
            # n_views is also the number of cameras used in this batch
            batch_size, n_views = images_batch.shape[0], images_batch.shape[1]
            n_joints = keypoints_3d_pred.shape[1]

            keypoints_3d_binary_validity_gt = (keypoints_3d_validity_gt > 0.0).type(torch.float32)

            # keep the untruncated gt for visualization
            keypoints_gt_original = keypoints_3d_gt.clone()

            # force ground truth keypoints to match the number of predicted joints
            if keypoints_3d_gt.shape[1] != n_joints:
                print(
                    f"[Warning] Possibly due to different pretrained model type, ground truth has {keypoints_3d_gt.shape[1]} keypoints while predicted has {n_joints} keypoints"
                )
                keypoints_3d_gt = keypoints_3d_gt[:, :n_joints, :]
                keypoints_3d_binary_validity_gt = keypoints_3d_binary_validity_gt[:, :n_joints, :]

            # 1-view case: express every joint relative to a base joint, for both gt and prediction
            # TODO: Totally remove for CMU dataset (which doesnt have pelvis-offset errors)?
            if n_views == 1:
                print(f"[{train_eval_mode}, {epoch}, {iter_i}] {config.kind} 1-view case: batch {iter_i}, images {images_batch.shape}")

                if config.kind == "human36m":
                    base_joint = 6
                elif config.kind in ["coco", "cmu", "cmupanoptic"]:
                    base_joint = 11
                else:
                    # previously crashed with an UnboundLocalError on base_joint
                    raise ValueError(f"No base joint defined for the 1-view case of kind {config.kind!r}")

                not_base_joint = torch.arange(n_joints) != base_joint

                keypoints_3d_gt_transformed = keypoints_3d_gt.clone()
                keypoints_3d_gt_transformed[:, not_base_joint] -= keypoints_3d_gt_transformed[:, base_joint:base_joint + 1]
                keypoints_3d_gt = keypoints_3d_gt_transformed

                keypoints_3d_pred_transformed = keypoints_3d_pred.clone()
                keypoints_3d_pred_transformed[:, not_base_joint] -= keypoints_3d_pred_transformed[:, base_joint:base_joint + 1]
                keypoints_3d_pred = keypoints_3d_pred_transformed

            # calculate loss
            if DEBUG:
                print(f"[{train_eval_mode}, {epoch}, {iter_i}] Calculating loss... ", end="")

            total_loss = 0.0

            loss = criterion(
                keypoints_3d_pred * scale_keypoints_3d,
                keypoints_3d_gt * scale_keypoints_3d_gt,
                keypoints_3d_binary_validity_gt
            )
            total_loss += loss
            metric_dict[f'{config.opt.criterion}'].append(loss.item())

            # volumetric ce loss
            if use_volumetric_ce_loss:
                loss = volumetric_ce_criterion(coord_volumes_pred, volumes_pred, keypoints_3d_gt, keypoints_3d_binary_validity_gt)
                metric_dict['volumetric_ce_loss'].append(loss.item())

                total_loss += volumetric_ce_loss_weight * loss

            metric_dict['total_loss'].append(total_loss.item())

            if DEBUG:
                print("Done!")

            if is_train:
                if DEBUG:
                    print(f"[{train_eval_mode}, {epoch}, {iter_i}] Backpropragating... ", end="")

                opt.zero_grad()
                total_loss.backward()

                if hasattr(config.opt, "grad_clip"):
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.opt.grad_clip / config.opt.lr)

                metric_dict['grad_norm_times_lr'].append(config.opt.lr * misc.calc_gradient_norm(filter(lambda x: x[1].requires_grad, model.named_parameters())))

                opt.step()

                if DEBUG:
                    print("Done!")

            # calculate metrics
            if DEBUG:
                print(f"[{train_eval_mode}, {epoch}, {iter_i}] Calculating metrics... ", end="")

            l2 = KeypointsL2Loss()(
                keypoints_3d_pred * scale_keypoints_3d,
                keypoints_3d_gt * scale_keypoints_3d_gt,
                keypoints_3d_binary_validity_gt
            )
            metric_dict['l2'].append(l2.item())

            # base point l2
            if base_points_pred is not None:
                if DEBUG:
                    print("\n\tCalculating base point metric...", end="")

                base_point_l2_list = []
                for batch_i in range(batch_size):
                    base_point_pred = base_points_pred[batch_i]

                    if config.model.kind == "coco":
                        # midpoint of gt joints 11 and 12
                        base_point_gt = (keypoints_3d_gt[batch_i, 11, :3] + keypoints_3d_gt[batch_i, 12, :3]) / 2
                    elif config.model.kind == "mpii":
                        base_point_gt = keypoints_3d_gt[batch_i, 6, :3]
                    elif config.model.kind == "cmu":
                        base_point_gt = keypoints_3d_gt[batch_i, 2, :3]
                    else:
                        raise ValueError(f"No base point defined for model kind {config.model.kind!r}")

                    base_point_l2_list.append(torch.sqrt(torch.sum((base_point_pred * scale_keypoints_3d - base_point_gt * scale_keypoints_3d_gt) ** 2)).item())

                base_point_l2 = 0.0 if len(base_point_l2_list) == 0 else np.mean(base_point_l2_list)
                metric_dict['base_point_l2'].append(base_point_l2)

                if DEBUG:
                    print("Done!")

            if DEBUG:
                print("Done!")

            # save answers for evaluation
            if not is_train:
                results['keypoints_3d'].append(keypoints_3d_pred.detach().cpu().numpy())
                results['indexes'].append(batch['indexes'])

                if save_extra_data:
                    extra_data['images'].append(batch['images'])
                    extra_data['detections'].append(batch['detections'])
                    extra_data['keypoints_3d_gt'].append(batch['keypoints_3d'])
                    extra_data['cameras'].append(batch['cameras'])

            # plot visualization
            # NOTE: transfer_cmu_h36m has a visualisation error, and connectivity dict needs to be h36m
            if master:
                if n_iters_total % config.vis_freq == 0:
                    vis_kind = getattr(config, "kind", "coco")
                    pred_kind = getattr(config, "pred_kind", None)

                    if transfer_cmu_h36m and pred_kind is None:
                        pred_kind = "human36m"

                    # NOTE: because of transferring, the original gt is drawn instead of the truncated one
                    # NOTE(review): pred_kind may still be None here, and it is passed
                    # as-is to the heatmap/volume visualizers -- confirm vis handles it.
                    for batch_i in range(min(batch_size, config.vis_n_elements)):
                        keypoints_vis = vis.visualize_batch(
                            images_batch, heatmaps_pred, keypoints_2d_pred, proj_matricies_batch,
                            keypoints_gt_original, keypoints_3d_pred,
                            kind=vis_kind,
                            cuboids_batch=cuboids_pred,
                            confidences_batch=confidences_pred,
                            batch_index=batch_i, size=5,
                            max_n_cols=10,
                            pred_kind=pred_kind
                        )
                        writer.add_image(f"{name}/keypoints_vis/{batch_i}", keypoints_vis.transpose(2, 0, 1), global_step=n_iters_total)

                        heatmaps_vis = vis.visualize_heatmaps(
                            images_batch, heatmaps_pred,
                            kind=pred_kind,
                            batch_index=batch_i, size=5,
                            max_n_rows=10, max_n_cols=10
                        )
                        writer.add_image(f"{name}/heatmaps/{batch_i}", heatmaps_vis.transpose(2, 0, 1), global_step=n_iters_total)

                        if model_type == "vol":
                            volumes_vis = vis.visualize_volumes(
                                images_batch, volumes_pred, proj_matricies_batch,
                                kind=pred_kind,
                                cuboids_batch=cuboids_pred,
                                batch_index=batch_i, size=5,
                                max_n_rows=1, max_n_cols=16
                            )
                            writer.add_image(f"{name}/volumes/{batch_i}", volumes_vis.transpose(2, 0, 1), global_step=n_iters_total)

                # dump weights to tensorboard
                if n_iters_total % config.vis_freq == 0:
                    for p_name, p in model.named_parameters():
                        try:
                            writer.add_histogram(p_name, p.clone().cpu().data.numpy(), n_iters_total)
                        except ValueError as e:
                            print(e)
                            print(p_name, p)
                            exit()

                # dump to tensorboard per-iter loss/metric stats
                if is_train:
                    for title, value in metric_dict.items():
                        writer.add_scalar(f"{name}/{title}", value[-1], n_iters_total)

                # measure elapsed time
                batch_time = time.time() - end
                end = time.time()

                # dump to tensorboard per-iter time stats
                writer.add_scalar(f"{name}/batch_time", batch_time, n_iters_total)
                writer.add_scalar(f"{name}/data_time", data_time, n_iters_total)

                # dump to tensorboard per-iter stats about sizes
                writer.add_scalar(f"{name}/batch_size", batch_size, n_iters_total)
                writer.add_scalar(f"{name}/n_views", n_views, n_iters_total)

                n_iters_total += 1

            if DEBUG:
                print(f"Training of epoch {epoch}, batch {iter_i} complete!")

    # calculate evaluation metrics
    if master:
        if not is_train:
            if DEBUG:
                print("Calculating evaluation metrics... ", end="")

            results['keypoints_3d'] = np.concatenate(results['keypoints_3d'], axis=0)
            results['indexes'] = np.concatenate(results['indexes'])

            try:
                scalar_metric, full_metric = dataloader.dataset.evaluate(results['keypoints_3d'])
            except Exception as e:
                print("Failed to evaluate. Reason: ", e)
                scalar_metric, full_metric = 0.0, {}

            metric_dict['dataset_metric'].append(scalar_metric)

            checkpoint_dir = os.path.join(experiment_dir, "checkpoints", "{:04}".format(epoch))
            os.makedirs(checkpoint_dir, exist_ok=True)

            if DEBUG:
                print("Calculated!")

            # dump results
            with open(os.path.join(checkpoint_dir, "results.pkl"), 'wb') as fout:
                if DEBUG:
                    print(f"Dumping results to {checkpoint_dir}/results.pkl... ", end="")
                pickle.dump(results, fout, protocol=4)
                if DEBUG:
                    print("Dumped!")

            # dump extra data as pkl file if need to reconstruct anything
            if save_extra_data:
                with open(os.path.join(checkpoint_dir, "extra_data.pkl"), 'wb') as fout:
                    if DEBUG:
                        print(f"Dumping extra data to {checkpoint_dir}/extra_data.pkl... ", end="")

                    pickle.dump(extra_data, fout, protocol=4)

                    if DEBUG:
                        print("Dumped!")

            # dump full metric
            with open(os.path.join(checkpoint_dir, "metric.json"), 'w') as fout:
                if DEBUG:
                    print(f"Dumping metric to {checkpoint_dir}/metric.json... ", end="")

                json.dump(full_metric, fout, indent=4, sort_keys=True)

                if DEBUG:
                    print("Dumped!")

        # dump to tensorboard per-epoch stats
        for title, value in metric_dict.items():
            writer.add_scalar(f"{name}/{title}_epoch", np.mean(value), epoch)

    print(f"Epoch {epoch} {train_eval_mode} complete!")

    return n_iters_total
def one_epoch(model,
              criterion,
              opt,
              config,
              dataloader,
              device,
              epoch,
              n_iters_total=0,
              caption='',
              master=False,
              experiment_dir=None,
              writer=None,
              is_train=False):
    """Run the model once over ``dataloader`` in inference ("Demo") mode.

    The model is always put in eval mode and run under ``torch.no_grad``.
    Nothing is trained here, whatever ``is_train`` says. ``is_train`` only
    controls the TensorBoard tag prefix (``"train"``/``"val"``), whether
    ``config.opt.n_iters_per_epoch_val`` limits the number of batches, and
    whether predictions are collected.

    Args:
        model: network called as ``model(images, proj_matricies, batch)``.
            Its output tuple depends on ``config.model.name``
            ("alg"/"ransac" or "vol").
        criterion, opt, caption, experiment_dir: accepted for signature
            compatibility; unused here.
        config: experiment config.
        dataloader: yields batches consumed by ``dataset_utils.prepare_batch``.
            ``None`` batches are skipped.
        device: device passed to ``prepare_batch``.
        epoch: epoch index used in log messages.
        n_iters_total: global iteration counter. It is advanced only on the
            master process.
        master: whether this process writes visualizations and stats to
            ``writer``.
        writer: TensorBoard-style writer (``add_image``/``add_scalar``/
            ``add_histogram``).
        is_train: see above. Defaults to ``False``. The original signature
            lacked this parameter even though the body read it.

    Returns:
        The updated ``n_iters_total``.

    Raises:
        NotImplementedError: for an unknown model type, or for a single-view
            batch whose ``config.kind`` has no known base joint.
    """
    name = "train" if is_train else "val"
    model_type = config.model.name

    model.eval()

    # NOTE(review): predictions (and optional extra data) are collected but
    # not consumed anywhere in this function.
    results = defaultdict(list)

    save_extra_data = config.save_extra_data if hasattr(
        config, "save_extra_data") else False

    if save_extra_data:
        extra_data = defaultdict(list)

    transfer_cmu_h36m = config.model.transfer_cmu_to_human36m if hasattr(
        config.model, "transfer_cmu_to_human36m") else False

    print("Transfer CMU to H36M: ", transfer_cmu_h36m)
    print("Using GT Pelvis position: ", config.model.use_gt_pelvis)
    print("Using cameras: ", dataloader.dataset.choose_cameras)
    print("Debug Mode: ", DEBUG)
    train_eval_mode = "Demo"

    # no gradients as we are only testing/evaluating
    with torch.no_grad():
        end = time.time()

        iterator = enumerate(dataloader)

        if not is_train and config.opt.n_iters_per_epoch_val is not None:
            iterator = islice(iterator, config.opt.n_iters_per_epoch_val)

        # Data breakdown:
        # - For each of the (max) 31 cameras in CMU dataset:
        #     - OpenCV Image: Numpy array [Note: likely cropped to smaller shape]
        #     - BBOX Detection for the image: (left, top, right, bottom) tuple
        #     - Camera: `Camera` object from `multiview.py`
        # - Index: int
        # - Keypoints (gt): NP Array, (17, 4)
        # - Keypoints (pred): NP Array, (17, 4) [Note: may not be there]

        for iter_i, batch in iterator:
            with autograd.detect_anomaly():
                # measure data loading time
                data_time = time.time() - end

                if batch is None:
                    print(
                        f"[{train_eval_mode}, {epoch}] Found None batch: {iter_i}"
                    )
                    continue

                if DEBUG:
                    print(f"{train_eval_mode} batch {iter_i}...")
                    print(
                        f"[{train_eval_mode}, {epoch}, {iter_i}] Preparing batch... ",
                        end="")

                images_batch, keypoints_3d_gt, keypoints_3d_validity_gt, proj_matricies_batch = dataset_utils.prepare_batch(
                    batch, device, config)

                if DEBUG:
                    print("Prepared!")

                if DEBUG:
                    print(
                        f"[{train_eval_mode}, {epoch}, {iter_i}] Running {model_type} model... ",
                        end="")

                keypoints_2d_pred, cuboids_pred, base_points_pred = None, None, None
                if model_type == "alg" or model_type == "ransac":
                    keypoints_3d_pred, keypoints_2d_pred, heatmaps_pred, confidences_pred = model(
                        images_batch, proj_matricies_batch, batch)
                elif model_type == "vol":
                    keypoints_3d_pred, heatmaps_pred, volumes_pred, confidences_pred, cuboids_pred, coord_volumes_pred, base_points_pred = model(
                        images_batch, proj_matricies_batch, batch)
                else:
                    raise NotImplementedError(
                        f"Unknown model type {model_type}")

                if DEBUG:
                    print("Done!")

                # images_batch.shape[2] is likely the number of channels;
                # n_views is the number of cameras used in this batch.
                batch_size, n_views = images_batch.shape[0], images_batch.shape[1]
                n_joints = keypoints_3d_pred.shape[1]

                # 1-view case: express every joint relative to the base joint.
                # TODO: Totally remove for CMU dataset (which doesnt have pelvis-offset errors)?
                if n_views == 1:
                    print(
                        f"[{train_eval_mode}, {epoch}, {iter_i}] {config.kind} 1-view case: batch {iter_i}, images {images_batch.shape}"
                    )

                    if config.kind == "human36m":
                        base_joint = 6
                    elif config.kind in ["coco", "cmu", "cmupanoptic"]:
                        base_joint = 11
                    else:
                        # Previously this fell through to an UnboundLocalError.
                        raise NotImplementedError(
                            f"No base joint defined for kind {config.kind!r}")

                    keypoints_3d_pred_transformed = keypoints_3d_pred.clone()
                    # The base joint itself is excluded from the mask, so the
                    # slice it is subtracted from is never modified.
                    non_base_joints = torch.arange(n_joints) != base_joint
                    keypoints_3d_pred_transformed[:, non_base_joints] -= \
                        keypoints_3d_pred_transformed[:, base_joint:base_joint + 1]
                    keypoints_3d_pred = keypoints_3d_pred_transformed

                if DEBUG:
                    print("Done!")

                # calculate metrics
                if DEBUG:
                    print(
                        f"[{train_eval_mode}, {epoch}, {iter_i}] Calculating metrics... ",
                        end="")

                # save answers for evalulation
                if not is_train:
                    results['keypoints_3d'].append(
                        keypoints_3d_pred.detach().cpu().numpy())
                    results['indexes'].append(batch['indexes'])

                    if save_extra_data:
                        extra_data['images'].append(batch['images'])
                        extra_data['detections'].append(batch['detections'])
                        extra_data['cameras'].append(batch['cameras'])

                # plot visualization
                # NOTE: transfer_cmu_h36m has a visualisation error, and connectivity dict needs to be h36m
                if master:
                    if n_iters_total % config.vis_freq == 0:
                        vis_kind = config.kind if hasattr(config,
                                                          "kind") else "coco"
                        pred_kind = config.pred_kind if hasattr(
                            config, "pred_kind") else None

                        if transfer_cmu_h36m and pred_kind is None:
                            pred_kind = "human36m"

                        # No ground truth is drawn (None is passed for it).
                        for batch_i in range(
                                min(batch_size, config.vis_n_elements)):
                            keypoints_vis = vis.visualize_batch(
                                images_batch,
                                heatmaps_pred,
                                keypoints_2d_pred,
                                proj_matricies_batch,
                                None,
                                keypoints_3d_pred,
                                kind=vis_kind,
                                cuboids_batch=cuboids_pred,
                                confidences_batch=confidences_pred,
                                batch_index=batch_i,
                                size=5,
                                max_n_cols=10,
                                pred_kind=pred_kind)
                            writer.add_image(f"{name}/keypoints_vis/{batch_i}",
                                             keypoints_vis.transpose(2, 0, 1),
                                             global_step=n_iters_total)

                            heatmaps_vis = vis.visualize_heatmaps(
                                images_batch,
                                heatmaps_pred,
                                kind=pred_kind,
                                batch_index=batch_i,
                                size=5,
                                max_n_rows=10,
                                max_n_cols=10)
                            writer.add_image(f"{name}/heatmaps/{batch_i}",
                                             heatmaps_vis.transpose(2, 0, 1),
                                             global_step=n_iters_total)

                            if model_type == "vol":
                                volumes_vis = vis.visualize_volumes(
                                    images_batch,
                                    volumes_pred,
                                    proj_matricies_batch,
                                    kind=pred_kind,
                                    cuboids_batch=cuboids_pred,
                                    batch_index=batch_i,
                                    size=5,
                                    max_n_rows=1,
                                    max_n_cols=16)
                                writer.add_image(f"{name}/volumes/{batch_i}",
                                                 volumes_vis.transpose(
                                                     2, 0, 1),
                                                 global_step=n_iters_total)

                    # dump weights to tensoboard
                    if n_iters_total % config.vis_freq == 0:
                        for p_name, p in model.named_parameters():
                            try:
                                writer.add_histogram(
                                    p_name,
                                    p.clone().cpu().data.numpy(),
                                    n_iters_total)
                            except ValueError as e:
                                print(e)
                                print(p_name, p)
                                exit()

                    # measure elapsed time
                    batch_time = time.time() - end
                    end = time.time()

                    # dump to tensorboard per-iter time stats
                    writer.add_scalar(f"{name}/batch_time", batch_time,
                                      n_iters_total)
                    writer.add_scalar(f"{name}/data_time", data_time,
                                      n_iters_total)

                    # dump to tensorboard per-iter stats about sizes
                    writer.add_scalar(f"{name}/batch_size", batch_size,
                                      n_iters_total)
                    writer.add_scalar(f"{name}/n_views", n_views,
                                      n_iters_total)

                    n_iters_total += 1

            if DEBUG:
                print(f"Training of epoch {epoch}, batch {iter_i} complete!")

    print(f"Epoch {epoch} {train_eval_mode} complete!")

    return n_iters_total
    def inferHuman36Data(self, batch, model_type, device, config, randomize_n_views,
                                        min_n_views,
                                        max_n_views):
        """Collate a list of samples into a batch and run ``self.model`` on it.

        Args:
            batch: uncollated samples, passed to the collate function from
                ``dataset_utils.make_collate_fn``.
            model_type: "alg"/"ransac" or "vol"; selects how the model output
                tuple is unpacked.
            device: device passed to ``dataset_utils.prepare_batch``.
            config: experiment config passed to ``prepare_batch``.
            randomize_n_views, min_n_views, max_n_views: forwarded to
                ``make_collate_fn``.

        Returns:
            ``(outputBatch, inputBatch)``:
            - ``outputBatch`` holds the model predictions. Keys that the
              selected model type does not produce are ``None``.
            - ``inputBatch`` holds the prepared images and projection matrices.

        Raises:
            NotImplementedError: for an unknown ``model_type``.
        """
        outputBatch = {}
        inputBatch = {}
        collatFunction = dataset_utils.make_collate_fn(randomize_n_views,
                                                       min_n_views,
                                                       max_n_views)
        batch = collatFunction(batch)
        images_batch, keypoints_3d_gt, keypoints_3d_validity_gt, proj_matricies_batch = dataset_utils.prepare_batch(batch, device, config)

        keypoints_2d_pred, cuboids_pred, base_points_pred, volumes_pred, coord_volumes_pred = None, None, None, None, None
        if model_type == "alg" or model_type == "ransac":
            keypoints_3d_pred, keypoints_2d_pred, heatmaps_pred, confidences_pred = self.model(images_batch, proj_matricies_batch, batch)
        elif model_type == "vol":
            keypoints_3d_pred, heatmaps_pred, volumes_pred, confidences_pred, cuboids_pred, coord_volumes_pred, base_points_pred = self.model(images_batch, proj_matricies_batch, batch)
        else:
            # Previously this surfaced later as an unbound-variable error.
            raise NotImplementedError(f"Unknown model type {model_type}")

        outputBatch["keypoints_3d_pred"] = keypoints_3d_pred
        outputBatch["heatmaps_pred"] = heatmaps_pred
        outputBatch["volumes_pred"] = volumes_pred
        outputBatch["confidences_pred"] = confidences_pred
        # Fixed: this key previously stored confidences_pred by mistake.
        outputBatch["cuboids_pred"] = cuboids_pred
        outputBatch["coord_volumes_pred"] = coord_volumes_pred
        outputBatch["base_points_pred"] = base_points_pred

        inputBatch["images_batch"] = images_batch
        inputBatch["proj_matricies_batch"] = proj_matricies_batch
        return outputBatch, inputBatch
def _base_point_gt(keypoints_3d_gt, batch_i, kind):
    """Return the ground-truth base point (first three coords) of sample ``batch_i``.

    Args:
        keypoints_3d_gt: ground-truth keypoints indexed as ``[batch, joint, coord]``.
        batch_i: index of the sample within the batch.
        kind: skeleton kind (``config.model.kind``).

    Raises:
        NotImplementedError: if ``kind`` has no base-point definition here.
    """
    if kind == "coco":
        # Midpoint of joints 11 and 12 (presumably the two hips; confirm
        # against the project's joint ordering).
        return (keypoints_3d_gt[batch_i, 11, :3] +
                keypoints_3d_gt[batch_i, 12, :3]) / 2
    if kind == "mpii":
        return keypoints_3d_gt[batch_i, 6, :3]
    raise NotImplementedError(
        f"No ground-truth base point defined for kind {kind!r}")


def _gather_distributed_results(results, dist_size):
    """All-gather the evaluation arrays from every rank, updating ``results`` in place.

    Each rank concatenates its per-batch arrays and zero-pads them to the
    largest per-rank count, so ``all_gather`` sees equal shapes. Each gathered
    chunk is then trimmed back to that rank's true count.

    Args:
        results: dict of per-batch arrays. Each gathered key is replaced by a
            single numpy array covering all ranks.
        dist_size: number of evaluation samples held by each rank, in rank order.
    """
    # max() rather than dist_size[-1]: the padded buffer must fit every rank.
    padded_size = max(dist_size)
    for term in ('keypoints_gt', 'keypoints_3d', 'proj_matricies_batch',
                 'indexes'):
        local = np.concatenate(results[term])
        buffer = [
            torch.zeros(padded_size, *local.shape[1:]).cuda()
            for _ in dist_size
        ]
        scatter_tensor = torch.zeros_like(buffer[0])
        scatter_tensor[:local.shape[0]] = torch.tensor(local).cuda()
        torch.distributed.all_gather(buffer, scatter_tensor)
        results[term] = torch.cat(
            [tensor[:n] for tensor, n in zip(buffer, dist_size)],
            dim=0).cpu().numpy()


def one_epoch_full(model,
                   criterion,
                   opt_dict,
                   config,
                   dataloader,
                   device,
                   epoch,
                   n_iters_total=0,
                   is_train=True,
                   lr=None,
                   mean_and_std=None,
                   limb_length=None,
                   caption='',
                   master=False,
                   experiment_dir=None,
                   writer=None,
                   whole_val_dataloader=None,
                   dist_size=None):
    """Run one training or validation epoch of the volumetric ("vol") model.

    Args:
        model: network returning the 10-tuple unpacked below. It is accessed
            through ``model.module`` when backbone weights are fixed
            (presumably a DataParallel/DDP wrapper).
        criterion: main keypoint loss, called as ``criterion(pred, gt, validity)``.
        opt_dict: mapping name -> optimizer. All are zeroed and stepped together.
        config: experiment config.
        dataloader: yields batches for ``dataset_utils.prepare_batch``.
            ``None`` batches are skipped.
        device: device passed to ``prepare_batch``.
        epoch: epoch index for logging and the checkpoint directory name.
        n_iters_total: global iteration counter, advanced on the master only.
        is_train: train (gradients + optimizer steps) or evaluate.
        lr: optional mapping name -> learning rate, logged while training.
        mean_and_std, limb_length, caption: unused here.
        master: whether this process shows progress, logs and evaluates.
        experiment_dir: root directory for per-epoch checkpoint dumps.
        writer: TensorBoard-style writer (``add_scalar``/``add_image``/
            ``add_histogram``).
        whole_val_dataloader: loader whose dataset evaluates the gathered
            predictions when ``dist_size`` is given.
        dist_size: per-rank sample counts for distributed evaluation, or
            ``None`` for single-process evaluation.

    Returns:
        The updated ``n_iters_total``.

    Raises:
        NotImplementedError: for a model type other than "vol", or for an
            unsupported ``config.model.kind`` when base points are predicted.
    """
    name = "train" if is_train else "val"
    model_type = config.model.name

    if is_train:
        if config.model.backbone.fix_weights:
            # Frozen backbone stays in eval mode; only the trainable heads
            # are switched to train mode.
            model.module.backbone.eval()
            if config.model.volume_net.use_feature_v2v:
                model.module.process_features.train()
            model.module.volume_net.train()
        else:
            model.train()
    else:
        model.eval()

    metric_dict = defaultdict(list)

    results = defaultdict(list)

    # used to turn on/off gradients
    grad_context = torch.autograd.enable_grad if is_train else torch.no_grad
    with grad_context():
        end = time.time()

        if master:
            if is_train and config.train.n_iters_per_epoch is not None:
                pbar = tqdm(
                    total=min(config.train.n_iters_per_epoch, len(dataloader)))
            else:
                pbar = tqdm(total=len(dataloader))

        iterator = enumerate(dataloader)
        if is_train and config.train.n_iters_per_epoch is not None:
            iterator = islice(iterator, config.train.n_iters_per_epoch)

        for iter_i, batch in iterator:
            # measure data loading time
            data_time = time.time() - end

            if batch is None:
                print("Found None batch")
                continue

            images_batch, keypoints_3d_gt, keypoints_validity_gt, proj_matricies_batch = dataset_utils.prepare_batch(
                batch, device, config)

            keypoints_2d_pred, cuboids_pred, base_points_pred = None, None, None
            if model_type == "vol":
                (voxel_keypoints_3d_pred, keypoints_3d_pred, heatmaps_pred,
                 volumes_pred, ga_mask_gt, atten_global, confidences_pred,
                 cuboids_pred, coord_volumes_pred,
                 base_points_pred) = model(images_batch, proj_matricies_batch,
                                           batch, keypoints_3d_gt)
            else:
                # Previously surfaced as an UnboundLocalError further down.
                raise NotImplementedError(f"Unknown model type {model_type}")

            batch_size, n_views = images_batch.shape[0], images_batch.shape[1]

            keypoints_binary_validity_gt = (keypoints_validity_gt > 0.0).type(
                torch.float32)

            scale_keypoints_3d = config.loss.scale_keypoints_3d

            # calculate loss
            total_loss = 0.0
            loss = criterion(keypoints_3d_pred * scale_keypoints_3d,
                             keypoints_3d_gt * scale_keypoints_3d,
                             keypoints_binary_validity_gt)
            total_loss += loss
            metric_dict[config.loss.criterion].append(loss.item())

            # volumetric ce loss
            if config.loss.use_volumetric_ce_loss:
                volumetric_ce_criterion = VolumetricCELoss()

                loss = volumetric_ce_criterion(coord_volumes_pred,
                                               volumes_pred, keypoints_3d_gt,
                                               keypoints_binary_validity_gt)
                metric_dict['volumetric_ce_loss'].append(loss.item())

                total_loss += config.loss.volumetric_ce_loss_weight * loss

            # global attention (3D heatmap) loss
            if config.loss.use_global_attention_loss:
                loss = nn.MSELoss(reduction='mean')(ga_mask_gt, atten_global)
                metric_dict['global_attention_loss'].append(loss.item())
                total_loss += config.loss.global_attention_loss_weight * loss

            metric_dict['total_loss'].append(total_loss.item())
            metric_dict['limb_length_error'].append(LimbLengthError()(
                keypoints_3d_pred.detach(), keypoints_3d_gt))

            # NaN losses skip the update entirely.
            if is_train and not torch.isnan(total_loss):
                for key in opt_dict.keys():
                    opt_dict[key].zero_grad()
                total_loss.backward()

                if config.loss.grad_clip:
                    # The threshold is divided by volume_net_lr. This
                    # presumably matches the grad_norm * lr quantity logged
                    # below.
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        config.loss.grad_clip / config.train.volume_net_lr)

                metric_dict['grad_norm_times_volume_net_lr'].append(
                    config.train.volume_net_lr * misc.calc_gradient_norm(
                        filter(lambda x: x[1].requires_grad,
                               model.named_parameters())))
                if lr is not None:
                    for key in lr.keys():
                        metric_dict['lr_{}'.format(key)].append(lr[key])

                for key in opt_dict.keys():
                    opt_dict[key].step()

            # calculate metrics
            l2 = KeypointsL2Loss()(keypoints_3d_pred * scale_keypoints_3d,
                                   keypoints_3d_gt * scale_keypoints_3d,
                                   keypoints_binary_validity_gt)
            metric_dict['l2'].append(l2.item())

            # base point l2
            if base_points_pred is not None:
                base_point_l2_list = []
                for batch_i in range(batch_size):
                    base_point_pred = base_points_pred[batch_i]
                    base_point_gt = _base_point_gt(keypoints_3d_gt, batch_i,
                                                   config.model.kind)

                    base_point_l2_list.append(
                        torch.sqrt(
                            torch.sum((base_point_pred * scale_keypoints_3d -
                                       base_point_gt *
                                       scale_keypoints_3d)**2)).item())

                base_point_l2 = 0.0 if len(
                    base_point_l2_list) == 0 else np.mean(base_point_l2_list)
                metric_dict['base_point_l2'].append(base_point_l2)

            # save answers for evalulation
            if not is_train:
                results['keypoints_gt'].append(
                    keypoints_3d_gt.detach().cpu().numpy())
                results['keypoints_3d'].append(
                    keypoints_3d_pred.detach().cpu().numpy())
                results['proj_matricies_batch'].append(
                    proj_matricies_batch.detach().cpu().numpy())
                results['indexes'].append(batch['indexes'])

            # plot visualization
            if master:
                if config.batch_output:
                    if n_iters_total % config.vis_freq == 0:
                        vis_kind = config.kind
                        if config.dataset.transfer_cmu_to_human36m:
                            vis_kind = "coco"

                        for batch_i in range(
                                min(batch_size, config.vis_n_elements)):
                            keypoints_vis = vis.visualize_batch(
                                images_batch,
                                heatmaps_pred,
                                keypoints_2d_pred,
                                proj_matricies_batch,
                                keypoints_3d_gt,
                                keypoints_3d_pred,
                                kind=vis_kind,
                                cuboids_batch=cuboids_pred,
                                confidences_batch=confidences_pred,
                                batch_index=batch_i,
                                size=5,
                                max_n_cols=10)
                            writer.add_image(
                                "{}/keypoints_vis/{}".format(name, batch_i),
                                keypoints_vis.transpose(2, 0, 1),
                                global_step=n_iters_total)

                            heatmaps_vis = vis.visualize_heatmaps(
                                images_batch,
                                heatmaps_pred,
                                kind=vis_kind,
                                batch_index=batch_i,
                                size=5,
                                max_n_rows=10,
                                max_n_cols=18)
                            writer.add_image(
                                "{}/heatmaps/{}".format(name, batch_i),
                                heatmaps_vis.transpose(2, 0, 1),
                                global_step=n_iters_total)

                            if model_type == "vol":
                                volumes_vis = vis.visualize_volumes(
                                    images_batch,
                                    volumes_pred,
                                    proj_matricies_batch,
                                    kind=vis_kind,
                                    cuboids_batch=cuboids_pred,
                                    batch_index=batch_i,
                                    size=5,
                                    max_n_rows=1,
                                    max_n_cols=18)
                                writer.add_image(
                                    "{}/volumes/{}".format(name, batch_i),
                                    volumes_vis.transpose(2, 0, 1),
                                    global_step=n_iters_total)

                    # dump weights to tensoboard
                    if n_iters_total % config.vis_freq == 0:
                        for p_name, p in model.named_parameters():
                            try:
                                writer.add_histogram(
                                    p_name,
                                    p.clone().cpu().data.numpy(),
                                    n_iters_total)
                            except ValueError as e:
                                print(e)
                                print(p_name, p)
                                exit()

                    # dump to tensorboard per-iter loss/metric stats
                    if is_train:
                        for title, value in metric_dict.items():
                            writer.add_scalar("{}/{}".format(name, title),
                                              value[-1], n_iters_total)

                    # measure elapsed time
                    batch_time = time.time() - end
                    end = time.time()

                    # dump to tensorboard per-iter time stats
                    writer.add_scalar("{}/batch_time".format(name), batch_time,
                                      n_iters_total)
                    writer.add_scalar("{}/data_time".format(name), data_time,
                                      n_iters_total)

                    # dump to tensorboard per-iter stats about sizes
                    writer.add_scalar("{}/batch_size".format(name), batch_size,
                                      n_iters_total)
                    writer.add_scalar("{}/n_views".format(name), n_views,
                                      n_iters_total)

                n_iters_total += 1
                pbar.update(1)

        if master:
            pbar.close()

    # calculate evaluation metrics
    if not is_train and dist_size is not None:
        # Collective operation: every rank must reach this call.
        _gather_distributed_results(results, dist_size)

    if master:
        if not is_train:
            # Best-effort evaluation: a failure is reported but does not abort
            # the epoch.
            try:
                if dist_size is None:
                    print('evaluating....')
                    scalar_metric, full_metric = dataloader.dataset.evaluate(
                        results['keypoints_gt'], results['keypoints_3d'],
                        results['proj_matricies_batch'], config)
                else:
                    scalar_metric, full_metric = whole_val_dataloader.dataset.evaluate(
                        results['keypoints_gt'], results['keypoints_3d'],
                        results['proj_matricies_batch'], config)
            except Exception as e:
                print("Failed to evaluate. Reason: ", e)
                scalar_metric, full_metric = 0.0, {}

            metric_dict['dataset_metric'].append(scalar_metric)
            metric_dict['limb_length_error'] = [
                LimbLengthError()(results['keypoints_3d'],
                                  results['keypoints_gt'])
            ]

            checkpoint_dir = os.path.join(experiment_dir, "checkpoints",
                                          "{:04}".format(epoch))
            os.makedirs(checkpoint_dir, exist_ok=True)

            # dump results
            with open(os.path.join(checkpoint_dir, "results.pkl"),
                      'wb') as fout:
                pickle.dump(results, fout)

            # dump full metric
            with open(os.path.join(checkpoint_dir, "metric.json"),
                      'w') as fout:
                json.dump(full_metric, fout, indent=4, sort_keys=True)

        # dump to tensorboard per-epoch stats
        for title, value in metric_dict.items():
            writer.add_scalar("{}/{}_epoch".format(name, title),
                              np.mean(value), epoch)

    return n_iters_total