def compute_separated_s_r_pckh(hg, img_list, c_list, s_list, r_list,
                               grnd_pts_list, grnd_heatmap_list,
                               normalizer_list):
    """Compute per-person PCKh separately for each of the two augmented
    copies of the batch and return the two PCKh tensors as a list."""
    assert len(img_list) == 2
    pckh_list = []
    for k in range(0, 2):
        # forward pass (inference only); only the last hourglass stack's
        # output is evaluated
        img_var = torch.autograd.Variable(img_list[k], volatile=True)
        out_reg = hg(img_var)
        tmp_pckh = Evaluation.per_person_pckh(out_reg[-1].data.cpu(),
                                              grnd_heatmap_list[k], c_list[k],
                                              s_list[k], [64, 64],
                                              grnd_pts_list[k],
                                              normalizer_list[k], r_list[k])
        pckh_list.append(tmp_pckh)

    return pckh_list
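

# Illustrative sketch (not part of the original pipeline): the two PCKh
# tensors returned by compute_separated_s_r_pckh can be turned into a
# per-person two-way target distribution over the two augmented variants,
# mirroring the lost-PCKh normalization used in collect_data below. The
# helper name and the uniform fallback are assumptions of this sketch.
def _pair_lost_pckh_to_distri(pckh_list):
    lost = [1 - p for p in pckh_list]      # lost PCKh per person, per variant
    distri = []
    for j in range(lost[0].size(0)):
        pair = torch.Tensor([float(lost[0][j]), float(lost[1][j])])
        if pair.sum() == 0:
            # both variants are already perfect: fall back to a uniform target
            pair = torch.ones(2) / 2
        else:
            pair = pair / pair.sum()       # normalize lost PCKh into probabilities
        distri.append(pair)
    return distri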


def collect_data(data_loader, dataset, hg, save_path):
    """Run the hourglass over every rotation variant of each sample and, for
    each person, build a target distribution over the rotation bins that is
    proportional to the lost PCKh (1 - PCKh) of each variant. The
    distributions are appended to save_path and also returned as a list."""
    rot_num = len(dataset.rotation_means)
    print 'rot_num: ', rot_num
    grnd_distri_list = []
    counter = 0
    hg.eval()
    for i, (img_list, heatmap_list, center_list, scale_list, rot_list,
            grnd_pts_list, normalizer_list, rot_idx, img_index, pts_aug_list)\
            in enumerate(data_loader):
        print '%d/%d' % (i, len(data_loader))
        # sanity check: the loader must deliver samples in dataset order so
        # that img_index matches the running counter
        val_img_index = torch.arange(counter, counter + len(img_index)).long()
        assert ((val_img_index - img_index).sum() == 0)
        counter += len(img_index)
        # each loader batch holds unique_num distinct persons, each provided
        # under all rot_num rotation variants
        unique_num = rot_idx.size(0)
        assert rot_idx.size(1) == rot_num
        # concatenate the per-variant batches along the batch dimension; the
        # resulting layout is [variant 0 persons, variant 1 persons, ...]
        img = torch.cat(img_list, dim=0)
        pts_aug = torch.cat(pts_aug_list, dim=0)
        heatmap = torch.cat(heatmap_list, dim=0)
        center = torch.cat(center_list, dim=0)
        scale = torch.cat(scale_list, dim=0)
        rotation = torch.cat(rot_list, dim=0)
        grnd_pts = torch.cat(grnd_pts_list, dim=0)
        normalizer = torch.cat(normalizer_list, dim=0)
        # forward pass (inference only) and per-person PCKh for every
        # concatenated sample
        img_var = torch.autograd.Variable(img, volatile=True)
        out_reg = hg(img_var)
        pckhs = Evaluation.per_person_pckh(out_reg[-1].data.cpu(), heatmap,
                                           center, scale, [64, 64], grnd_pts,
                                           normalizer, rotation)
        # lost PCKh: the larger it is, the harder that rotation variant was
        lost_pckhs = 1 - pckhs
        for j in range(0, unique_num):
            # indices j, j + unique_num, ... pick the rot_num rotation
            # variants of person j (see the concatenation layout above)
            tmp_pckhs = lost_pckhs[j::unique_num]
            assert (tmp_pckhs.size(0) == rot_num)
            if tmp_pckhs.sum() == 0:
                # every rotation variant is already perfect: use a uniform target
                print 'tmp_pckh: ', tmp_pckhs
                print 'sum of tmp_pckh is zero. Setting equal probabilities ...'
                tmp_distri = torch.ones(tmp_pckhs.size(0)) / tmp_pckhs.size(0)
            elif (tmp_pckhs < 0).any():
                print 'tmp_pckh: ', tmp_pckhs
                print 'some entries of tmp_pckh are negative. error ...'
                exit()
            else:
                # normalize the lost PCKh into a probability distribution
                tmp_distri = tmp_pckhs.clone()
                tmp_distri = tmp_distri / tmp_distri.sum()
            grnd_distri_list.append(tmp_distri)
            with open(save_path, 'a+') as log_file:
                tmp_distri = tmp_distri.numpy()
                np.savetxt(log_file,
                           tmp_distri.reshape(1, tmp_distri.shape[0]),
                           fmt='%.2f')

    return grnd_distri_list
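

# Toy sanity check (illustrative only, not used by the pipeline above): it
# documents the interleaved layout produced by torch.cat(img_list, dim=0) in
# collect_data, where index v * unique_num + j holds person j under rotation
# variant v, so slicing with [j::unique_num] collects all rot_num variants of
# person j. The numbers below are made up.
def _demo_interleaved_slicing():
    rot_num, unique_num = 3, 2
    # fake lost PCKh laid out as [p0_v0, p1_v0, p0_v1, p1_v1, p0_v2, p1_v2]
    lost_pckhs = torch.Tensor([0.1, 0.4, 0.2, 0.3, 0.7, 0.3])
    per_person = []
    for j in range(unique_num):
        tmp = lost_pckhs[j::unique_num]     # the rot_num variants of person j
        assert tmp.size(0) == rot_num
        per_person.append(tmp / tmp.sum())  # same normalization as collect_data
    # person 0 gets [0.1, 0.2, 0.7] / 1.0, person 1 gets [0.4, 0.3, 0.3] / 1.0
    return per_person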