def train(train_loader, net, optimizer, epoch, visualizer, idx, opt):
    """Run one training epoch and return the averaged loss and accuracy.

    For every batch the module-level ``quan_op`` is driven in this order:
    ``quantization()`` before the forward pass, then ``restore()`` and
    ``updateQuanGradWeight()`` after ``loss.backward()`` and before
    ``optimizer.step()``.  (Presumably this swaps quantized weights in for
    the forward/backward pass and lets the optimizer update the restored
    weights -- TODO confirm against ``quan_op``'s implementation.)

    The loss is the sum, over every element of the network output, of the
    mean squared error between that element and the target heatmap.
    Accuracy metrics are computed from the last output only.

    Uses the PyTorch >= 0.4 tensor API (``non_blocking=``, ``.item()``).

    Args:
        train_loader: Iterable with ``len()`` yielding
            ``(img, heatmap, center, scale, rot, grnd_pts, normalizer)``.
        net: Model called as ``net(img)``; must return an indexable,
            iterable collection of predictions comparable to ``heatmap``.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        epoch: Forwarded to ``visualizer.print_log`` only.
        visualizer: Object providing
            ``print_log(epoch, i, total, value1=...)``.
        idx: Forwarded unchanged to ``Evaluation.accuracy``.
        opt: Options object; only ``opt.print_freq`` is read.

    Returns:
        Tuple ``(losses.avg, pckhs_origin_res.avg)``.
    """
    losses = AverageMeter()
    pckhs = AverageMeter()
    pckhs_origin_res = AverageMeter()

    # switch to train mode
    net.train()

    num_batches = len(train_loader)
    for i, (img, heatmap, center, scale, rot, grnd_pts,
            normalizer) in enumerate(train_loader):
        quan_op.quantization()

        # `async=` is a syntax error on Python >= 3.7; `non_blocking=` is
        # the PyTorch >= 0.4 spelling of the same option.
        target = heatmap.cuda(non_blocking=True)

        # forward pass; one MSE term per network output
        output = net(img)
        loss = sum(((per_out - target) ** 2).mean() for per_out in output)

        # gradient and SGD step; quan_op calls must stay between backward()
        # and step(), in this order
        optimizer.zero_grad()
        loss.backward()
        quan_op.restore()
        quan_op.updateQuanGradWeight()
        optimizer.step()

        losses.update(loss.item())

        # Copy the final prediction to the CPU once and reuse it for both
        # metrics.
        final_out = output[-1].detach().cpu()
        pckh = Evaluation.accuracy(final_out, target.cpu(), idx)
        pckhs.update(pckh[0])
        # NOTE(review): [64, 64] is presumably the heatmap resolution --
        # confirm against Evaluation.accuracy_origin_res.
        pckh_origin_res = Evaluation.accuracy_origin_res(
            final_out, center, scale, [64, 64], grnd_pts, normalizer, rot)
        pckhs_origin_res.update(pckh_origin_res[0])

        if i % opt.print_freq == 0 or i == num_batches - 1:
            loss_dict = OrderedDict([
                ('loss', losses.avg),
                ('pckh', pckhs.avg),
                ('pckh_origin_res', pckhs_origin_res.avg),
            ])
            visualizer.print_log(epoch, i, num_batches, value1=loss_dict)
    return losses.avg, pckhs_origin_res.avg
def validate(val_loader, net, epoch, visualizer, idx, joint_flip_index,
             num_classes):
    """Evaluate ``net`` on ``val_loader`` with horizontal-flip averaging.

    The module-level ``quan_op.quantization()`` is called once before the
    loop and ``quan_op.restore()`` once afterwards; the restore is in a
    ``finally`` clause so it also runs if a batch raises.

    For each batch the network is run on the image and on its copy
    reversed along the last axis.  The last output of the flipped pass is
    passed through ``HumanAug.flip_channels`` and
    ``HumanAug.shuffle_channels_for_horizontal_flipping`` and averaged
    with the last output of the unflipped pass.  The loss is computed from
    the unflipped pass only (sum over outputs of the mean squared error
    against the target heatmap).

    Uses the PyTorch >= 0.4 API (``torch.no_grad``, ``non_blocking=``,
    ``.item()``).

    Args:
        val_loader: Iterable with ``len()`` and a ``dataset`` attribute,
            yielding ``(img, heatmap, center, scale, rot, grnd_pts,
            normalizer, index)``; ``index[n]`` selects the row of the
            returned predictions that sample ``n`` is written to.
        net: Model called as ``net(img)``; must return an indexable,
            iterable collection of predictions comparable to ``heatmap``.
        epoch: Forwarded to ``visualizer.print_log`` only.
        visualizer: Object providing
            ``print_log(epoch, i, total, value1=...)``.
        idx: Forwarded unchanged to ``Evaluation.accuracy``.
        joint_flip_index: Forwarded unchanged to
            ``HumanAug.shuffle_channels_for_horizontal_flipping``.
        num_classes: Size of the second dimension of the predictions
            tensor.

    Returns:
        Tuple ``(losses.avg, pckhs_origin_res.avg, predictions)`` where
        ``predictions`` has shape
        ``(len(val_loader.dataset), num_classes, 2)``.
    """
    losses = AverageMeter()
    pckhs = AverageMeter()
    pckhs_origin_res = AverageMeter()
    # Zero-initialised so rows that are never written hold a defined value
    # instead of uninitialised memory.
    predictions = torch.zeros(len(val_loader.dataset), num_classes, 2)

    # switch to evaluate mode
    net.eval()

    num_batches = len(val_loader)
    quan_op.quantization()
    try:
        # `volatile=True` is ignored by PyTorch >= 0.4; no_grad() is what
        # actually disables autograd bookkeeping during evaluation.
        with torch.no_grad():
            for i, (img, heatmap, center, scale, rot, grnd_pts, normalizer,
                    index) in enumerate(val_loader):
                # `async=` is a syntax error on Python >= 3.7.
                target = heatmap.cuda(non_blocking=True)

                # unflipped pass and loss
                output1 = net(img)
                loss = sum(((per_out - target) ** 2).mean()
                           for per_out in output1)

                # flipped pass: reverse the last image axis, then map the
                # prediction back before averaging
                img_flip = torch.from_numpy(
                    img.numpy()[:, :, :, ::-1].copy())
                output2 = net(img_flip)
                output2 = HumanAug.flip_channels(output2[-1].detach().cpu())
                output2 = HumanAug.shuffle_channels_for_horizontal_flipping(
                    output2, joint_flip_index)
                output = (output1[-1].detach().cpu() + output2) / 2

                # metrics
                pckh = Evaluation.accuracy(output, target.cpu(), idx)
                pckhs.update(pckh[0])
                # NOTE(review): [64, 64] is presumably the heatmap
                # resolution -- confirm against the Evaluation helpers.
                pckh_origin_res = Evaluation.accuracy_origin_res(
                    output, center, scale, [64, 64], grnd_pts, normalizer,
                    rot)
                pckhs_origin_res.update(pckh_origin_res[0])

                # print log
                losses.update(loss.item())
                loss_dict = OrderedDict([
                    ('loss', losses.avg),
                    ('pckh', pckhs.avg),
                    ('pckh_origin_res', pckhs_origin_res.avg),
                ])
                visualizer.print_log(epoch, i, num_batches,
                                     value1=loss_dict)

                # store per-sample predictions at their dataset index
                preds = Evaluation.final_preds(output, center, scale,
                                               [64, 64], rot)
                for n in range(output.size(0)):
                    predictions[index[n], :, :] = preds[n, :, :]
    finally:
        # Always undo quantization(), even when evaluation fails part-way.
        quan_op.restore()
    return losses.avg, pckhs_origin_res.avg, predictions