Example #1
  def process(self, images, return_time=False):
    with torch.no_grad():
      torch.cuda.synchronize()
      output = self.model(images)[-1]
      output['hm'] = output['hm'].sigmoid_()
      if self.opt.hm_hp and not self.opt.mse_loss:
        output['hm_hp'] = output['hm_hp'].sigmoid_()

      reg = output['reg'] if self.opt.reg_offset else None
      hm_hp = output['hm_hp'] if self.opt.hm_hp else None
      hp_offset = output['hp_offset'] if self.opt.reg_hp_offset else None
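      # synchronize again so forward_time below reflects the finished GPU forward pass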
      torch.cuda.synchronize()
      forward_time = time.time()
      
      if self.opt.flip_test:
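        # average the predictions of the original (batch index 0) and horizontally
        # flipped (index 1) images; keypoint channels are swapped left/right
        # via flip_lr / flip_lr_off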
        output['hm'] = (output['hm'][0:1] + flip_tensor(output['hm'][1:2])) / 2
        output['wh'] = (output['wh'][0:1] + flip_tensor(output['wh'][1:2])) / 2
        output['hps'] = (output['hps'][0:1] + 
          flip_lr_off(output['hps'][1:2], self.flip_idx)) / 2
        hm_hp = (hm_hp[0:1] + flip_lr(hm_hp[1:2], self.flip_idx)) / 2 \
                if hm_hp is not None else None
        reg = reg[0:1] if reg is not None else None
        hp_offset = hp_offset[0:1] if hp_offset is not None else None
      
      dets = multi_pose_decode(
        output['hm'], output['wh'], output['hps'],
        reg=reg, hm_hp=hm_hp, hp_offset=hp_offset, K=self.opt.K)

    if return_time:
      return output, dets, forward_time
    else:
      return output, dets
Example #2
    def process(self, images, return_time=False):
        with torch.no_grad():
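            # the normal inference path below is disabled (if False), so this variant
            # always takes the else branch: export the model to ONNX and exit via quit()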
            if False:
                torch.cuda.synchronize()
                output = self.model(images)[-1]
                output['hm'] = output['hm'].sigmoid_()
                if self.opt.hm_hp and not self.opt.mse_loss:
                    output['hm_hp'] = output['hm_hp'].sigmoid_()

                reg = output['reg'] if self.opt.reg_offset else None
                hm_hp = output['hm_hp'] if self.opt.hm_hp else None
                hp_offset = output[
                    'hp_offset'] if self.opt.reg_hp_offset else None
                torch.cuda.synchronize()
                forward_time = time.time()

                if self.opt.flip_test:
                    output['hm'] = (output['hm'][0:1] +
                                    flip_tensor(output['hm'][1:2])) / 2
                    output['wh'] = (output['wh'][0:1] +
                                    flip_tensor(output['wh'][1:2])) / 2
                    output['hps'] = (output['hps'][0:1] + flip_lr_off(
                        output['hps'][1:2], self.flip_idx)) / 2
                    hm_hp = (hm_hp[0:1] + flip_lr(hm_hp[1:2], self.flip_idx)) / 2 \
                            if hm_hp is not None else None
                    reg = reg[0:1] if reg is not None else None
                    hp_offset = hp_offset[
                        0:1] if hp_offset is not None else None

                dets = multi_pose_decode(output['hm'],
                                         output['wh'],
                                         output['hps'],
                                         reg=reg,
                                         hm_hp=hm_hp,
                                         hp_offset=hp_offset,
                                         K=self.opt.K)
            else:
                hm, wh, reg, hps, hm_hp, hp_offset = self.model(images)
                names = ['hm', 'wh', 'reg', 'hps', 'hm_hp', 'hp_offset']
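                # export the forward graph to pose.onnx with named input/output tensors;
                # quit() stops the process once the file is written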
                torch.onnx.export(self.model,
                                  images,
                                  "pose.onnx",
                                  opset_version=9,
                                  verbose=False,
                                  input_names=["input"],
                                  output_names=names)
                quit()

        if return_time:
            return output, dets, forward_time
        else:
            return output, dets
Example #3
    def process(self, images, return_time=False):
        with torch.no_grad():
            torch.cuda.synchronize()

            output = self.model(images)[-1]

            output['hm'] = output['hm'].sigmoid_()
            if self.opt.hm_hp and not self.opt.mse_loss:
                output['hm_hp'] = output['hm_hp'].sigmoid_()
            reg = output['reg'] if self.opt.reg_offset else None
            hm_hp = output['hm_hp'] if self.opt.hm_hp else None
            hp_offset = output['hp_offset'] if self.opt.reg_hp_offset else None
            torch.cuda.synchronize()
            forward_time = time.time()

            if self.opt.flip_test:
                output['hm'] = (output['hm'][0:1] +
                                flip_tensor(output['hm'][1:2])) / 2
                output['wh'] = (output['wh'][0:1] +
                                flip_tensor(output['wh'][1:2])) / 2
                output['hps'] = (output['hps'][0:1] + flip_lr_off(
                    output['hps'][1:2], self.flip_idx)) / 2
                hm_hp = (hm_hp[0:1] + flip_lr(hm_hp[1:2], self.flip_idx)) / 2 \
                        if hm_hp is not None else None
                reg = reg[0:1] if reg is not None else None
                hp_offset = hp_offset[0:1] if hp_offset is not None else None
            '''
            output['hm']:  [b, 1, h, w]   center-point heatmap
            output['wh']:  [b, 2, h, w]   box size (width, height)
            reg:           [b, 2, h, w]   sub-pixel offset of the center point

            output['hps']: [b, 34, h, w]  offsets of the 17 joints relative to the center (17 x 2)
            hm_hp:         [b, 17, h, w]  per-joint heatmap
            hp_offset:     [b, 2, h, w]   sub-pixel offset of the joint heatmap peaks
            where h, w = in_h / down_ratio, in_w / down_ratio (down_ratio is the output stride)
            '''
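            # decode the heatmaps into the top-K detections: bbox (4), score (1),
            # 17 keypoints (34), class id (1)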
            dets = multi_pose_decode(output['hm'],
                                     output['wh'],
                                     output['hps'],
                                     reg=reg,
                                     hm_hp=hm_hp,
                                     hp_offset=hp_offset,
                                     K=self.opt.K)

        if return_time:
            return output, dets, forward_time
        else:
            return output, dets
Example #4
    def process(self, images, return_time=False):
        with torch.no_grad():
            torch.cuda.synchronize()
            outputs = self.model(images)
            #hm, wh, hps, reg, hm_hp, hp_offset = outputs
            hm, wh, hps, reg, hm_hp, hp_offset, seg_feat, seg = outputs
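            # this whole-body variant additionally predicts segmentation heads
            # (seg_feat, seg), which whole_body_decode consumes below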

            hm = hm.sigmoid_()
            if self.cfg.LOSS.HM_HP and not self.cfg.LOSS.MSE_LOSS:
                hm_hp = hm_hp.sigmoid_()

            reg = reg if self.cfg.LOSS.REG_OFFSET else None
            hm_hp = hm_hp if self.cfg.LOSS.HM_HP else None
            hp_offset = hp_offset if self.cfg.LOSS.REG_HP_OFFSET else None
            torch.cuda.synchronize()
            forward_time = time.time()

            if self.cfg.TEST.FLIP_TEST:
                hm = (hm[0:1] + flip_tensor(hm[1:2])) / 2
                wh = (wh[0:1] + flip_tensor(wh[1:2])) / 2
                hps = (hps[0:1] + flip_lr_off(hps[1:2], self.flip_idx)) / 2
                hm_hp = (hm_hp[0:1] + flip_lr(hm_hp[1:2], self.flip_idx)) / 2 \
                        if hm_hp is not None else None
                reg = reg[0:1] if reg is not None else None
                hp_offset = hp_offset[0:1] if hp_offset is not None else None

            dets = whole_body_decode(hm,
                                     wh,
                                     hps,
                                     seg_feat=seg_feat,
                                     seg=seg,
                                     reg=reg,
                                     hm_hp=hm_hp,
                                     hp_offset=hp_offset,
                                     K=self.cfg.TEST.TOPK)

        if return_time:
            return outputs, dets, forward_time
        else:
            return outputs, dets
Example #5
    def process(self, images, return_time=False):
        with torch.no_grad():
            torch.cuda.synchronize()
            if self.opt.zjb:
                all_out = self.model(images)
                output = all_out[-1]
                # gt: debug/oracle path, replace the predicted joint heatmap with the ground truth
                output['hm_hp'] = self.gt[0]['hm_hp'].cuda()
                if self.opt.hm_hp and not self.opt.mse_loss:
                    output['hm_hp'] = output['hm_hp'].sigmoid_()

                hm_hp = output['hm_hp'] if self.opt.hm_hp else None
                hp_offset = output[
                    'hp_offset'] if self.opt.reg_hp_offset else None
                torch.cuda.synchronize()
                forward_time = time.time()

                #T-param

                # if self.opt.flip_test:
                #   output['hm'] = (output['hm'][0:1] + flip_tensor(output['hm'][1:2])) / 2
                #   output['wh'] = (output['wh'][0:1] + flip_tensor(output['wh'][1:2])) / 2
                #   output['hps'] = (output['hps'][0:1] +
                #     flip_lr_off(output['hps'][1:2], self.flip_idx)) / 2
                #   hm_hp = (hm_hp[0:1] + flip_lr(hm_hp[1:2], self.flip_idx)) / 2 \
                #           if hm_hp is not None else None
                #   reg = reg[0:1] if reg is not None else None
                #   hp_offset = hp_offset[0:1] if hp_offset is not None else None

                dets, center = multi_pose_decode_c(output,
                                                   output['hps'],
                                                   hm_hp=hm_hp,
                                                   hp_offset=hp_offset,
                                                   K=self.opt.K)
                dets = dets.data.cpu().numpy().reshape(2, -1, 8)
                center = center.data.cpu().numpy().reshape(2, -1, 4)
                dets = dets.reshape(1, -1, 8)
                center = center.reshape(1, -1, 4)
                detections = dets
                center_points = center
                classes = detections[..., -1]
                classes = classes[0]
                detections = detections[0]
                center_points = center_points[0]

                valid_ind = detections[:, 4] > -1
                valid_detections = detections[valid_ind]

                box_width = valid_detections[:, 2] - valid_detections[:, 0]
                box_height = valid_detections[:, 3] - valid_detections[:, 1]

                s_ind = (box_width * box_height <= 22500)
                l_ind = (box_width * box_height > 22500)
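                # split by box area at 150*150 px: small boxes use the central 1/3 of the
                # box, large boxes the central 1/5, as the region a center point must hit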

                s_detections = valid_detections[s_ind]
                l_detections = valid_detections[l_ind]

                s_left_x = (2 * s_detections[:, 0] + s_detections[:, 2]) / 3
                s_right_x = (s_detections[:, 0] + 2 * s_detections[:, 2]) / 3
                s_top_y = (2 * s_detections[:, 1] + s_detections[:, 3]) / 3
                s_bottom_y = (s_detections[:, 1] + 2 * s_detections[:, 3]) / 3

                s_temp_score = copy.copy(s_detections[:, 4])
                s_detections[:, 4] = -1
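                # scores are reset to -1 and only restored below when a same-class
                # center point lies inside the central region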

                center_x = center_points[:, 0][:, np.newaxis]
                center_y = center_points[:, 1][:, np.newaxis]
                s_left_x = s_left_x[np.newaxis, :]
                s_right_x = s_right_x[np.newaxis, :]
                s_top_y = s_top_y[np.newaxis, :]
                s_bottom_y = s_bottom_y[np.newaxis, :]

                ind_lx = (center_x - s_left_x) > 0
                ind_rx = (center_x - s_right_x) < 0
                ind_ty = (center_y - s_top_y) > 0
                ind_by = (center_y - s_bottom_y) < 0
                ind_cls = (center_points[:, 2][:, np.newaxis] -
                           s_detections[:, -1][np.newaxis, :]) == 0
                ind_s_new_score = np.max(
                    ((ind_lx + 0) & (ind_rx + 0) & (ind_ty + 0) &
                     (ind_by + 0) & (ind_cls + 0)),
                    axis=0) == 1
                index_s_new_score = np.argmax(
                    ((ind_lx + 0) & (ind_rx + 0) & (ind_ty + 0) & (ind_by + 0)
                     & (ind_cls + 0))[:, ind_s_new_score],
                    axis=0)
                s_detections[:, 4][ind_s_new_score] = (
                    s_temp_score[ind_s_new_score] * 2 +
                    center_points[index_s_new_score, 3]) / 3
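                # matched detections are rescored as (2*det_score + center_score) / 3;
                # the same procedure follows for the large boxes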

                l_left_x = (3 * l_detections[:, 0] +
                            2 * l_detections[:, 2]) / 5
                l_right_x = (2 * l_detections[:, 0] +
                             3 * l_detections[:, 2]) / 5
                l_top_y = (3 * l_detections[:, 1] + 2 * l_detections[:, 3]) / 5
                l_bottom_y = (2 * l_detections[:, 1] +
                              3 * l_detections[:, 3]) / 5

                l_temp_score = copy.copy(l_detections[:, 4])
                l_detections[:, 4] = -1

                center_x = center_points[:, 0][:, np.newaxis]
                center_y = center_points[:, 1][:, np.newaxis]
                l_left_x = l_left_x[np.newaxis, :]
                l_right_x = l_right_x[np.newaxis, :]
                l_top_y = l_top_y[np.newaxis, :]
                l_bottom_y = l_bottom_y[np.newaxis, :]

                ind_lx = (center_x - l_left_x) > 0
                ind_rx = (center_x - l_right_x) < 0
                ind_ty = (center_y - l_top_y) > 0
                ind_by = (center_y - l_bottom_y) < 0
                ind_cls = (center_points[:, 2][:, np.newaxis] -
                           l_detections[:, -1][np.newaxis, :]) == 0
                ind_l_new_score = np.max(
                    ((ind_lx + 0) & (ind_rx + 0) & (ind_ty + 0) &
                     (ind_by + 0) & (ind_cls + 0)),
                    axis=0) == 1
                index_l_new_score = np.argmax(
                    ((ind_lx + 0) & (ind_rx + 0) & (ind_ty + 0) & (ind_by + 0)
                     & (ind_cls + 0))[:, ind_l_new_score],
                    axis=0)
                l_detections[:, 4][ind_l_new_score] = (
                    l_temp_score[ind_l_new_score] * 2 +
                    center_points[index_l_new_score, 3]) / 3

                detections = np.concatenate([l_detections, s_detections],
                                            axis=0)
                detections = detections[np.argsort(-detections[:, 4])]
                classes = detections[..., -1]

                #nms
                keep_inds = (detections[:, 4] > 0)
                #keep_inds  = (detections[:, 4] > )
                detections = detections[keep_inds]
                classes = classes[keep_inds]
                nms_inds = py_nms(detections[:, 0:5], 0.5)
                detections = detections[nms_inds]
                classes = classes[nms_inds]

                #gt
                # det_gt = self.gt[0]['gt_det']
                # detections = np.ones((len(det_gt), 8))
                # for i in range(len(det_gt)):
                #   detections[i][:4] = det_gt[i].numpy()
                #   detections[i][7] = 0
                #
                # if detections.shape[0] == 0:
                #   self.non_person += 1
                #   print("nonperson"+str(self.non_person))
                #   self.last_dets[0][0, :, 4] = 0
                #   self.last_dets[0] = np.zeros(shape=(1, 1, 6+34))
                #   if return_time:
                #     return output, self.last_dets[0], forward_time
                #   else:
                #     return output, self.last_dets[0]

                kps = multi_pose_decode_c1(output,
                                           detections,
                                           output['hps'],
                                           hm_hp=hm_hp,
                                           hp_offset=hp_offset,
                                           K=self.opt.K).data.cpu().numpy()[0]

                num_j = kps.shape[1]
                dets = np.zeros(shape=(1, detections.shape[0], 6 + num_j))
                dets[0, :, 0:5] = detections[:, 0:5]
                dets[0, :, 5:5 + num_j] = kps
                dets[0, :, -1] = detections[:, -1]

                # top_bboxes[image_id] = {}
                # for j in range(categories):
                #     keep_inds = (classes == j)
                #     top_bboxes[image_id][j + 1] = detections[keep_inds][:, 0:7].astype(np.float32)
                #     if merge_bbox:
                #         soft_nms_merge(top_bboxes[image_id][j + 1], Nt=nms_threshold, method=nms_algorithm, weight_exp=weight_exp)
                #     else:
                #         soft_nms(top_bboxes[image_id][j + 1], Nt=nms_threshold, method=nms_algorithm)
                #     top_bboxes[image_id][j + 1] = top_bboxes[image_id][j + 1][:, 0:5]
                #
                # scores = np.hstack([
                #     top_bboxes[image_id][j][:, -1]
                #     for j in range(1, categories + 1)
                # ])
                # if len(scores) > max_per_image:
                #     kth    = len(scores) - max_per_image
                #     thresh = np.partition(scores, kth)[kth]
                #     for j in range(1, categories + 1):
                #         keep_inds = (top_bboxes[image_id][j][:, -1] >= thresh)
                #         top_bboxes[image_id][j] = top_bboxes[image_id][j][keep_inds]

                self.last_dets[0] = dets
                if return_time:
                    return output, dets, forward_time
                else:
                    return output, dets

            else:
                output = self.model(images)[-1]
                output['hm'] = output['hm'].sigmoid_()
                if self.opt.hm_hp and not self.opt.mse_loss:
                    output['hm_hp'] = output['hm_hp'].sigmoid_()

                reg = output['reg'] if self.opt.reg_offset else None
                hm_hp = output['hm_hp'] if self.opt.hm_hp else None
                hp_offset = output[
                    'hp_offset'] if self.opt.reg_hp_offset else None
                torch.cuda.synchronize()
                forward_time = time.time()

                if self.opt.flip_test:
                    output['hm'] = (output['hm'][0:1] +
                                    flip_tensor(output['hm'][1:2])) / 2
                    output['wh'] = (output['wh'][0:1] +
                                    flip_tensor(output['wh'][1:2])) / 2
                    output['hps'] = (output['hps'][0:1] + flip_lr_off(
                        output['hps'][1:2], self.flip_idx)) / 2
                    hm_hp = (hm_hp[0:1] + flip_lr(hm_hp[1:2], self.flip_idx)) / 2 \
                            if hm_hp is not None else None
                    reg = reg[0:1] if reg is not None else None
                    hp_offset = hp_offset[
                        0:1] if hp_offset is not None else None

                dets = multi_pose_decode(output['hm'],
                                         output['wh'],
                                         output['hps'],
                                         reg=reg,
                                         hm_hp=hm_hp,
                                         hp_offset=hp_offset,
                                         K=self.opt.K)

                if return_time:
                    return output, dets, forward_time
                else:
                    return output, dets
Example #6
    def process(self, images, return_time=False):
        with torch.no_grad():
            torch.cuda.synchronize()
            output = self.model(images)[-1]
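            # the commented block below is an alternative inference path that runs the
            # same heads through planer/cupy and wraps the results back into torch tensors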
            # hm_x, hm_hp_x, wh, hps, reg, hp_offset, hm_max, hm_hp_max = self.model(images)
            # hm_x, hm_hp_x, wh, hps, reg, hp_offset = self.model(images)
            #
            # import planer
            # from planer import read_net
            # import cupy as cp
            # images_cp = cp.asarray(images.cpu().data.numpy())
            # pal = planer.core(cp)
            # net = read_net('res18')
            #
            # hm_x, hm_hp_x, wh, hps, reg, hp_offset = net(images_cp)
            # hm_x = torch.tensor(cp.asnumpy(hm_x).astype('float32')).cuda()
            # hm_hp_x = torch.tensor(cp.asnumpy(hm_hp_x).astype('float32')).cuda()
            # wh = torch.tensor(cp.asnumpy(wh).astype('float32')).cuda()
            # hps = torch.tensor(cp.asnumpy(hps).astype('float32')).cuda()
            # reg = torch.tensor(cp.asnumpy(reg).astype('float32')).cuda()
            # hp_offset = torch.tensor(cp.asnumpy(hp_offset).astype('float32')).cuda()

            # output = {'hm': hm_x, 'hm_hp': hm_hp_x, 'wh': wh, 'hps': hps, 'reg': reg,
            #           'hp_offset': hp_offset,
            #           # 'hm_max': hm_max, 'hm_hp_max': hm_hp_max
            #           }

            output['hm'] = output['hm'].sigmoid_()
            # output['hm_max'] = output['hm_max'].sigmoid_()
            if self.opt.hm_hp and not self.opt.mse_loss:
                output['hm_hp'] = output['hm_hp'].sigmoid_()
                # output['hm_hp_max'] = output['hm_hp_max'].sigmoid_()

            reg = output['reg'] if self.opt.reg_offset else None
            hm_hp = output['hm_hp'] if self.opt.hm_hp else None
            hp_offset = output['hp_offset'] if self.opt.reg_hp_offset else None
            torch.cuda.synchronize()
            forward_time = time.time()

            if self.opt.flip_test:
                output['hm'] = (output['hm'][0:1] +
                                flip_tensor(output['hm'][1:2])) / 2
                output['wh'] = (output['wh'][0:1] +
                                flip_tensor(output['wh'][1:2])) / 2
                output['hps'] = (output['hps'][0:1] + flip_lr_off(
                    output['hps'][1:2], self.flip_idx)) / 2
                hm_hp = (hm_hp[0:1] + flip_lr(hm_hp[1:2], self.flip_idx)) / 2 \
                        if hm_hp is not None else None
                reg = reg[0:1] if reg is not None else None
                hp_offset = hp_offset[0:1] if hp_offset is not None else None

            dets = multi_pose_decode(
                output['hm'],
                output['wh'],
                output['hps'],
                # output['hm_max'], output['hm_hp_max'],
                reg=reg,
                hm_hp=hm_hp,
                hp_offset=hp_offset,
                K=self.opt.K)

        if return_time:
            return output, dets, forward_time
        else:
            return output, dets
Example #7
  def process(self, images, return_time=False):
    with torch.no_grad():
      torch.cuda.synchronize()
      output = self.model(images)[-1]
      if self.opt.mdn:
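        # mixture density head: mdn_pi are the mixture weights (clamped softmax),
        # mdn_sigma the per-component spread (ELU + minimum sigma), and mdn_mu the
        # component means, i.e. the candidate keypoint offsets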
        mdn_logits = output['mdn_logits']
        mdn_pi = torch.clamp(torch.nn.Softmax(dim=1)(mdn_logits), 1e-4, 1. - 1e-4)
        mdn_sigma = torch.clamp(
            torch.nn.ELU()(output['mdn_sigma']) + self.opt.mdn_min_sigma, 1e-4, 1e5)
        mdn_mu = output['hps']

        # print("mdn_pi.shape:",mdn_pi.shape)
        # print("mdn_mu.shape:",mdn_mu.shape)
        # print("mdn_sigma.shape:",mdn_sigma.shape)
        (BS,_,H,W) = mdn_sigma.shape

        if self.opt.mdn_limit_comp is not None:
          M= mdn_pi.shape[1]
          C = mdn_mu.shape[1]//M
          cid=self.opt.mdn_limit_comp
          mdn_pi = mdn_pi[:,cid:cid+1,:,:]
          mdn_sigma=mdn_sigma[:,2*cid:2*cid+2,:,:]
          mdn_mu = mdn_mu[:,C*cid:C*cid+C,:,:]

        M= mdn_pi.shape[1]
        mdn_sigma = torch.reshape(mdn_sigma, (BS,M,2,H,W))
        C = mdn_mu.shape[1]//M
        mdn_mu = torch.reshape(mdn_mu, (BS,M,C,H,W))

        if self.opt.mdn_48:
          central = mdn_pi * torch.reciprocal(mdn_sigma[:,:,0,:,:])**C * torch.reciprocal(mdn_sigma[:,:,1,:,:])**C
          pi_max,pi_max_idx = torch.max(central,1)
        else:
          pi_max,pi_max_idx = torch.max(mdn_pi,1)
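        # either keep the single most likely component (mdn_max / mdn_48) or take the
        # pi-weighted average over all components (mdn_sum)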
        if self.opt.mdn_max or self.opt.mdn_48:
          a = pi_max_idx.unsqueeze(1).repeat(1,C,1,1).reshape(BS,1,C,H,W)
          hps = torch.gather(mdn_mu,1,a).squeeze(1)

          a = pi_max_idx.unsqueeze(1).repeat(1,2,1,1).reshape(BS,1,2,H,W)
          sigmas = torch.gather(mdn_sigma,1,a).squeeze(1)
        elif self.opt.mdn_sum:
          hps = torch.sum(mdn_mu*mdn_pi.unsqueeze(2),1)
          sigmas = torch.sum(mdn_sigma*mdn_pi.unsqueeze(2),1)

        output.update({'hps':hps})
        #if self.opt.debug == 4:
        output.update({'mdn_max_idx':pi_max_idx.unsqueeze(1),
                       'mdn_sigmas':sigmas,
                       'mdn_max_pi':pi_max.unsqueeze(1)
                       })
        
      output['hm'] = output['hm'].sigmoid_()
      if self.opt.hm_hp:
        output['hm_hp'] = output['hm_hp'].sigmoid_()

      reg = output['reg'] if self.opt.reg_offset else None
      hm_hp = output['hm_hp'] if self.opt.hm_hp else None
      hp_offset = output['hp_offset'] if self.opt.reg_hp_offset else None
      torch.cuda.synchronize()
      forward_time = time.time()
      
      if self.opt.flip_test or self.opt.flip_test_max:
        output['hm'] = (output['hm'][0:1] + flip_tensor(output['hm'][1:2])) / 2
        output['wh'] = (output['wh'][0:1] + flip_tensor(output['wh'][1:2])) / 2
        hm_hp = (hm_hp[0:1] + flip_lr(hm_hp[1:2], self.flip_idx)) / 2 \
                if hm_hp is not None else None
        reg = reg[0:1] if reg is not None else None
        hp_offset = hp_offset[0:1] if hp_offset is not None else None
        if self.opt.mdn:
          output['mdn_max_idx'][1:2] = flip_tensor(output['mdn_max_idx'][1:2])
          output['mdn_max_pi'][1:2] = flip_tensor(output['mdn_max_pi'][1:2])
          output['mdn_sigmas'][1:2] = flip_tensor(output['mdn_sigmas'][1:2])

      if self.opt.flip_test:
        output['hps'] = (output['hps'][0:1] + 
          flip_lr_off(output['hps'][1:2], self.flip_idx)) / 2
      
        if self.opt.mdn:
          output['mdn_sigmas'] = (output['mdn_sigmas'][0:1] + output['mdn_sigmas'][1:2] ) / 2
     
      elif self.opt.flip_test_max:
        if self.opt.mdn:

          output['hps'][1:2] = flip_lr_off(output['hps'][1:2], self.flip_idx)

          #print("output['mdn_max_pi'].shape:",output['mdn_max_pi'].shape)
          _,pi_max_idx = torch.max(output['mdn_max_pi'],0)
          _,_,H,W =output['hps'].shape 
          a = pi_max_idx.unsqueeze(0).repeat(1,34,1,1).reshape(1,34,H,W)
          output['hps']= torch.gather(output['hps'],0,a)
          a = pi_max_idx.unsqueeze(0).repeat(1,2,1,1).reshape(1,2,H,W)
          output['mdn_sigmas']= torch.gather(output['mdn_sigmas'],0,a)

      dets = multi_pose_decode(
        output['hm'], output['wh'], output['hps'],
        reg=reg, hm_hp=hm_hp, hp_offset=hp_offset, K=self.opt.K,
        mdn_max_idx=output.get('mdn_max_idx'),
        mdn_max_pi=output.get('mdn_max_pi'),
        mdn_sigmas=output.get('mdn_sigmas'))

    if return_time:
      return output, dets, forward_time
    else:
      return output, dets