def main():
    if CONFIG["CUDA"]:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")

    weight_name = CONFIG["model"]["pretrained_weight"]
    model_dict = torch.load(weight_name, map_location='cpu')
    
    source_net = PoseEstimationWithMobileNet()
    target_net = PoseEstimationWithMobileNet()

    load_state(source_net, model_dict)
    load_state(target_net, model_dict)

    discriminator = Discriminator()
    criterion = nn.BCELoss()

    source_net = source_net.cuda(CONFIG["GPU"]["source_net"])
    target_net = target_net.cuda(CONFIG["GPU"]["target_net"])
    discriminator = discriminator.to(device)
    criterion = criterion.to(device)

    optimizer_tg = torch.optim.Adam(target_net.parameters(),
                                   lr=CONFIG["training_setting"]["t_lr"])
    optimizer_d = torch.optim.Adam(discriminator.parameters(),
                                  lr=CONFIG["training_setting"]["d_lr"])

    dataset = ADDADataset()
    dataloader = DataLoader(dataset, CONFIG["dataset"]["batch_size"], shuffle=True, num_workers=0)

    trainer = Trainer(source_net, target_net, discriminator, 
                     dataloader, optimizer_tg, optimizer_d, criterion, device)
    trainer.train()
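
The Trainer used above is not shown in this snippet. Below is a minimal sketch of an ADDA-style training step; it assumes the dataloader yields (source, target) image pairs, that both networks return a list of stage outputs, that Discriminator ends in a sigmoid, and that everything sits on a single device — all of these are assumptions, not the original implementation.

import torch

# Hypothetical sketch of the ADDA training loop assumed by the snippet above.
class Trainer:
    def __init__(self, source_net, target_net, discriminator,
                 dataloader, optimizer_tg, optimizer_d, criterion, device):
        self.source_net = source_net
        self.target_net = target_net
        self.discriminator = discriminator
        self.dataloader = dataloader
        self.optimizer_tg = optimizer_tg
        self.optimizer_d = optimizer_d
        self.criterion = criterion
        self.device = device

    def train(self):
        for source_img, target_img in self.dataloader:
            source_img = source_img.to(self.device)
            target_img = target_img.to(self.device)

            # 1) Discriminator step: source features labelled 1, target features 0.
            self.optimizer_d.zero_grad()
            with torch.no_grad():
                src_feat = self.source_net(source_img)[-1]
                tgt_feat = self.target_net(target_img)[-1]
            pred = self.discriminator(torch.cat([src_feat, tgt_feat]))
            label = torch.cat([torch.ones(src_feat.size(0), 1),
                               torch.zeros(tgt_feat.size(0), 1)]).to(self.device)
            d_loss = self.criterion(pred, label)
            d_loss.backward()
            self.optimizer_d.step()

            # 2) Target step: update target_net so its features fool the discriminator (label 1).
            self.optimizer_tg.zero_grad()
            tgt_feat = self.target_net(target_img)[-1]
            pred = self.discriminator(tgt_feat)
            g_loss = self.criterion(pred, torch.ones(tgt_feat.size(0), 1).to(self.device))
            g_loss.backward()
            self.optimizer_tg.step()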
    def __init__(self, checkpoint_path, device,
                 img_mean=np.array([128, 128, 128], dtype=np.float32),
                 img_scale=np.float32(1/255),
                 use_tensorrt=False):
        from models.with_mobilenet import PoseEstimationWithMobileNet
        from modules.load_state import load_state
        self.img_mean = img_mean
        self.img_scale = img_scale
        self.device = torch.device('cpu')
        if device != 'CPU':
            if torch.cuda.is_available():
                self.device = torch.device('cuda:0')
            else:
                print('No CUDA device found, inferring on CPU')

        net = PoseEstimationWithMobileNet()
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        if use_tensorrt:
            from torch2trt import TRTModule
            net = TRTModule()
            net.load_state_dict(checkpoint)
        else:
            load_state(net, checkpoint)
            net = net.to(self.device)
        net.eval()
        self.net = net
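
The wrapper above only builds the network; a matching inference helper is not part of the snippet. Here is a minimal sketch of the usual preprocessing with the stored img_mean/img_scale. The 256-pixel input height is an assumption taken from the other examples on this page, and unlike infer_fast no stride padding is applied.

import cv2
import numpy as np
import torch

def infer(estimator, img, height_size=256):
    # Resize so the input height matches the network's expected height.
    scale = height_size / img.shape[0]
    scaled = cv2.resize(img, (0, 0), fx=scale, fy=scale)
    # Normalize with the wrapper's stored mean and scale, then HWC -> NCHW.
    normed = (scaled.astype(np.float32) - estimator.img_mean) * estimator.img_scale
    tensor = torch.from_numpy(normed).permute(2, 0, 1).unsqueeze(0).to(estimator.device)
    with torch.no_grad():
        stages_output = estimator.net(tensor)
    # The last two outputs are the final-stage heatmaps and PAFs.
    return stages_output[-2], stages_output[-1]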
Example #3
def main():

    net = PoseEstimationWithMobileNet()
    checkpoint = torch.load("models/checkpoint_iter_370000.pth",
                            map_location='cpu')
    load_state(net, checkpoint)
    net = net.cuda()
    done = threading.Event()

    with anki_vector.AsyncRobot() as robot:
        robot.camera.init_camera_feed()
        robot.camera.image_streaming_enabled()

        # preparing robot pose ready
        robot.behavior.set_head_angle(degrees(25.0))
        robot.behavior.set_lift_height(0.0)

        #events for detection and new camera feed
        robot.events.subscribe(on_new_raw_camera_image,
                               events.Events.new_raw_camera_image, net)
        robot.events.subscribe_by_name(on_robot_observed_touch,
                                       event_name='touched')

        print(
            "------ waiting for camera events, press ctrl+c to exit early ------"
        )

        try:
            if not done.wait(timeout=600):
                print("------ Did not receive a new camera image! ------")
        except KeyboardInterrupt:
            pass
Example #4
    def __init__(self):
        self.name = 'OpenPose'
        net = PoseEstimationWithMobileNet()
        checkpoint = torch.load('./checkpoint_iter_370000.pth',
                                map_location='cpu')
        load_state(net, checkpoint)
        self.net = net.eval()
        if envars.USE_GPU():
            self.net = self.net.cuda()
        self.stride = 8
        self.upsample_ratio = 4
        self.height_size = 256
        self.kpt_names = [
            'nose', 'neck', 'r_sho', 'r_elb', 'r_wri', 'l_sho', 'l_elb',
            'l_wri', 'r_hip', 'r_knee', 'r_ank', 'l_hip', 'l_knee', 'l_ank',
            'r_eye', 'l_eye', 'r_ear', 'l_ear'
        ]
        self.connections = [('nose', 'r_eye'), ('r_eye', 'r_ear'),
                            ('nose', 'l_eye'), ('l_eye', 'l_ear'),
                            ('nose', 'neck'), ('neck', 'r_sho'),
                            ('r_sho', 'r_elb'), ('r_elb', 'r_wri'),
                            ('neck', 'l_sho'), ('l_sho', 'l_elb'),
                            ('l_elb', 'l_wri'), ('neck', 'r_hip'),
                            ('r_hip', 'r_knee'), ('r_knee', 'r_ank'),
                            ('neck', 'l_hip'), ('l_hip', 'l_knee'),
                            ('l_knee', 'l_ank')]
Example #5
    def __init__(self, checkpoint_path, scale=256.):
        super().__init__()
        self.scale = scale

        pose_model = PoseEstimationWithMobileNet()
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        load_state(pose_model, state_dict)
        self.pose_model = pose_model
def init_pose(checkpoint_path):
    net = PoseEstimationWithMobileNet()
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    load_state(net, checkpoint)

    net.eval()
    net = net.cuda()

    env = [net, []]

    return None, env
Example #7
def openpose_to_jit():

    x = torch.randn(1,3,256,456)

    net = PoseEstimationWithMobileNet().cpu()
    checkpoint = torch.load(r'.\weights\checkpoint_iter_370000.pth', map_location='cpu')
    load_state(net, checkpoint)
    net.eval()
    net(x)
    script_model = torch.jit.trace(net, x)
    script_model.save('test.jit')
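
Loading the traced module back is symmetric; a short usage sketch (the file name test.jit comes from the snippet above):

import torch

model = torch.jit.load('test.jit')
model.eval()
with torch.no_grad():
    stages_output = model(torch.randn(1, 3, 256, 456))
print([o.shape for o in stages_output])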
def main1():
    parser = argparse.ArgumentParser(
        description='''Lightweight human pose estimation python demo.
                       This is just for quick results preview.
                       Please, consider c++ demo for the best performance.''')
    parser.add_argument('--checkpoint-path',
                        type=str,
                        required=True,
                        help='path to the checkpoint')
    parser.add_argument('--height-size',
                        type=int,
                        default=256,
                        help='network input layer height size')
    parser.add_argument('--video',
                        type=str,
                        default='',
                        help='path to video file or camera id')
    parser.add_argument('--images',
                        nargs='+',
                        default='',
                        help='path to input image(s)')
    parser.add_argument('--images_dir',
                        default='',
                        help='folderpath to input image(s)')
    parser.add_argument('--cpu',
                        action='store_true',
                        help='run network inference on cpu')
    parser.add_argument('--track',
                        type=int,
                        default=1,
                        help='track pose id in video')
    parser.add_argument('--smooth',
                        type=int,
                        default=1,
                        help='smooth pose keypoints')
    args = parser.parse_args()

    if args.video == '' and args.images == '' and args.images_dir == '':
        raise ValueError('Either --video, --images or --images_dir has to be provided')

    net = PoseEstimationWithMobileNet()
    checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
    load_state(net, checkpoint)

    frame_provider = ImageReader(args.images)
    if not args.images_dir == '':
        frame_provider = ImageReader(args.images_dir)
    if args.video != '':
        frame_provider = VideoReader(args.video)
    else:
        args.track = 0

    run_demo(net, frame_provider, args.height_size, args.cpu, args.track,
             args.smooth)
Example #9
def init(cpu = False):
    net = PoseEstimationWithMobileNet()

    checkpoint_path = "checkpoint_iter_370000.pth"
    checkpoint = torch.load(checkpoint_path, map_location='cpu') #load the existing model
    load_state(net, checkpoint)

    net = net.eval()
    if not cpu:
        net = net.cuda()

    return net
def run_demo(args, image_provider, height_size, cpu, track, smooth):

    net = PoseEstimationWithMobileNet()
    checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
    load_state(net, checkpoint)

    net = net.eval()
    if not cpu:
        net = net.cuda()

    stride = 8
    upsample_ratio = 4
    num_keypoints = Pose.num_kpts
    previous_poses = []
    delay = 33
    for d in image_provider:
        img, image_name = d["image"], d["image_name"]
        orig_img = img.copy()
        heatmaps, pafs, scale, pad = infer_fast(net, img, height_size, stride,
                                                upsample_ratio, cpu)

        total_keypoints_num = 0
        all_keypoints_by_type = []
        for kpt_idx in range(num_keypoints):  # 19th for bg
            total_keypoints_num += extract_keypoints(heatmaps[:, :, kpt_idx],
                                                     all_keypoints_by_type,
                                                     total_keypoints_num)

        pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type,
                                                      pafs,
                                                      demo=True)
        for kpt_id in range(all_keypoints.shape[0]):
            all_keypoints[kpt_id, 0] = (all_keypoints[kpt_id, 0] * stride /
                                        upsample_ratio - pad[1]) / scale
            all_keypoints[kpt_id, 1] = (all_keypoints[kpt_id, 1] * stride /
                                        upsample_ratio - pad[0]) / scale
        keypoints_out = []
        for n in range(len(pose_entries)):
            if len(pose_entries[n]) == 0:
                continue
            pose_keypoints = np.ones((num_keypoints, 3), dtype=np.float32) * -1
            for kpt_id in range(num_keypoints):
                if pose_entries[n][kpt_id] != -1.0:  # keypoint was found
                    pose_keypoints[kpt_id, 0] = int(
                        all_keypoints[int(pose_entries[n][kpt_id]), 0])
                    pose_keypoints[kpt_id, 1] = int(
                        all_keypoints[int(pose_entries[n][kpt_id]), 1])
                    pose_keypoints[kpt_id, 2] = 0.94 if pose_keypoints[kpt_id, 0] != -1 else 0
            keypoints_out.append(pose_keypoints)
        save_json(image_name, keypoints_out, args)
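
save_json is called above but not defined in this example. A plausible minimal implementation, writing one JSON file per image; the output directory and the file-naming scheme are assumptions:

import json
import os

def save_json(image_name, keypoints_out, args, out_dir='json_out'):
    # Hypothetical sketch: dump the per-image keypoint arrays as plain lists.
    os.makedirs(out_dir, exist_ok=True)
    base = os.path.splitext(os.path.basename(image_name))[0]
    with open(os.path.join(out_dir, base + '.json'), 'w') as f:
        json.dump({'image': image_name,
                   'keypoints': [kpts.tolist() for kpts in keypoints_out]}, f)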
Example #11
    def __init__(self, model_path: str):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.net = PoseEstimationWithMobileNet()
        self.net.to(self.device)
        checkpoint = torch.load(model_path, map_location=self.device)
        load_state(self.net, checkpoint)
        self.net.eval()
        self.image = None
        self.avg_heatmap = None
        self.avg_paf = None
        self.track = True
        self.stride = 8
        self.upsample_ratio = 4
        self.height_size = 256
        self.smooth = 1
        self.num_keypoints = Pose.num_kpts
Example #12
    def _pose_dect_init(self, device):
        """Initialize the pose detection model.

        Arguments:
            device {torch.device}: device to implement the models on.

        Returns:
            PoseEstimationWithMobileNet: initialized OpenPose model.
        """

        weight_path = self.__params.pose_weights
        model = PoseEstimationWithMobileNet()
        weight = torch.load(weight_path, map_location='cpu')
        load_state(model, weight)
        model = model.eval()
        if device.type != 'cpu':
            model = model.cuda()

        return model
Example #13
def loadEmotion():
    print("-Loading pose estimation neural net...")
    poseEstNet = PoseEstimationWithMobileNet()
    poseEstNet = poseEstNet.cuda()
    checkpoint = torch.load(
        "../lightweight-human-pose-estimation.pytorch/checkpoint_iter_370000.pth",
        map_location='cpu')
    load_state(poseEstNet, checkpoint)
    print("-Pose estimation neural net loaded")

    print("-Opening body language decoders...")
    bodymove = BLmovements()
    bodydecode = BLdecode()
    print("-Body language decoder loaded")

    print("-Loading facial emotion neural net...")
    facialEmotionNet, faceRecNet, image_size = ssd_infer.load()
    print("-Facial emotion neural net loaded")

    return poseEstNet, bodymove, bodydecode, facialEmotionNet, faceRecNet, image_size
def upload():
    if request.method == 'POST':
        file = request.files['file']
        extension = os.path.splitext(file.filename)[1]
        f_name = "lastphoto.jpg"
        file.save(os.path.join('uploads', f_name))

        image = cv2.imread("uploads/lastphoto.jpg", cv2.IMREAD_COLOR)

        net = PoseEstimationWithMobileNet()
        checkpoint = torch.load("checkpoint_iter_370000.pth.tar",
                                map_location='cpu')
        load_state(net, checkpoint)

        return json.dumps(
            {
                'filename': f_name,
                'humans': run_demo(net, image, 256, True)
            },
            cls=NumpyEncoder)
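
NumpyEncoder is passed to json.dumps above but never defined; the usual pattern is a json.JSONEncoder subclass that converts NumPy types to plain Python, roughly:

import json
import numpy as np

class NumpyEncoder(json.JSONEncoder):
    # Serialize NumPy arrays and scalars that json.dumps cannot handle natively.
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)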
from modules.load_state import load_state


def convert_to_onnx(net, output_name):
    # Despite the name, this variant saves a TensorRT-compiled state dict;
    # the plain ONNX export is left commented out below.
    net = net.cuda().eval()  # torch2trt expects the model and inputs on the GPU
    input = torch.zeros(1, 3, 256, 448).cuda()
    input_names = ['data']
    output_names = ['features', 'heatmaps', 'pafs']
    model_trt = torch2trt(net, [input], fp16_mode=True)
    torch.save(model_trt.state_dict(), output_name)
    # torch.onnx.export(net, input, output_name, verbose=True, input_names=input_names, output_names=output_names)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint-path',
                        type=str,
                        default='models/human-pose-estimation-3d.pth',
                        help='path to the checkpoint')
    parser.add_argument('--output-name',
                        type=str,
                        default='human-pose-estimation-3d-trt.pth',
                        help='name of the output (TensorRT) model file')
    args = parser.parse_args()

    net = PoseEstimationWithMobileNet(is_convertible_by_mo=True)
    checkpoint = torch.load(args.checkpoint_path)
    load_state(net, checkpoint)

    convert_to_onnx(net, args.output_name)
    print('=====================done=====================')
def main():
    # The head of this snippet was truncated; the parser setup below is
    # reconstructed to mirror the identical demo (main1) above, so the
    # earlier arguments are assumptions.
    parser = argparse.ArgumentParser(
        description='Lightweight human pose estimation python demo.')
    parser.add_argument('--checkpoint-path',
                        type=str,
                        required=True,
                        help='path to the checkpoint')
    parser.add_argument('--height-size',
                        type=int,
                        default=256,
                        help='network input layer height size')
    parser.add_argument('--video',
                        type=str,
                        default='',
                        help='path to video file or camera id')
    parser.add_argument('--images',
                        nargs='+',
                        default='',
                        help='path to input image(s)')
    parser.add_argument('--cpu',
                        action='store_true',
                        help='run network inference on cpu')
    parser.add_argument('--track',
                        type=int,
                        default=1,
                        help='track pose id in video')
    parser.add_argument('--smooth',
                        type=int,
                        default=1,
                        help='smooth pose keypoints')
    args = parser.parse_args()

    if args.video == '' and args.images == '':
        raise ValueError('Either --video or --images has to be provided')

    net = PoseEstimationWithMobileNet(num_heatmaps=25, num_pafs=50)
    checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
    load_state(net, checkpoint)

    frame_provider = ImageReader(args.images)
    if args.video != '':
        frame_provider = VideoReader(args.video)
    else:
        args.track = 0

    run_demo(net, frame_provider, args.height_size, args.cpu, args.track,
             args.smooth)
Example #17
    def callback(self, data):
        try:
            cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        except CvBridgeError as e:
            print(e)
            return  # no frame to process

        # Rescale image size
        rescale_factor = 1
        width = int(cv_image.shape[1] * rescale_factor)
        height = int(cv_image.shape[0] * rescale_factor)
        dim = (width, height)
        resized_img = cv2.resize(cv_image, dim)

        # NOTE: rebuilding and reloading the network on every callback is
        # expensive; ideally this would be done once in the node's __init__.
        net = PoseEstimationWithMobileNet()
        checkpoint = torch.load(
            "/home/zheng/lightweight-human-pose-estimation.pytorch/checkpoint_iter_370000.pth",
            map_location='cpu')
        load_state(net, checkpoint)
        height_size = 256
        net = net.eval()
        net = net.cuda()

        stride = 8
        upsample_ratio = 4
        num_keypoints = Pose.num_kpts
        previous_poses = []
        delay = 33
        # img = cv2.imread("/home/zheng/lightweight-human-pose-estimation.pytorch/data/image_1400.jpg")
        img = asarray(cv_image)
        orig_img = img
        heatmaps, pafs, scale, pad = infer_fast(net,
                                                img,
                                                height_size,
                                                stride,
                                                upsample_ratio,
                                                cpu=False)  # the net was moved to CUDA above

        total_keypoints_num = 0
        all_keypoints_by_type = []
        for kpt_idx in range(num_keypoints):  # 19th for bg
            total_keypoints_num += extract_keypoints(heatmaps[:, :, kpt_idx],
                                                     all_keypoints_by_type,
                                                     total_keypoints_num)

        pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type,
                                                      pafs,
                                                      demo=True)

        for kpt_id in range(all_keypoints.shape[0]):
            all_keypoints[kpt_id, 0] = (all_keypoints[kpt_id, 0] * stride /
                                        upsample_ratio - pad[1]) / scale
            all_keypoints[kpt_id, 1] = (all_keypoints[kpt_id, 1] * stride /
                                        upsample_ratio - pad[0]) / scale
        current_poses = []

        # Collect all keypoints in a numpy array to send to ROS
        pose_keypoints_ros_data = np.zeros(16)
        my_array_for_publishing = Float32MultiArray()

        ####
        pose_keypoints = np.ones((num_keypoints, 2), dtype=np.int32) * -1
        if len(pose_entries) == 0:  # nobody detected in this frame
            return
        for kpt_id in range(8):
            if pose_entries[0][kpt_id] != -1.0:  # keypoint was found
                pose_keypoints[kpt_id, 0] = int(
                    all_keypoints[int(pose_entries[0][kpt_id]), 0])
                pose_keypoints[kpt_id, 1] = int(
                    all_keypoints[int(pose_entries[0][kpt_id]), 1])
            pose_keypoints_ros_data[2 * kpt_id] = pose_keypoints[kpt_id][0]
            pose_keypoints_ros_data[2 * kpt_id + 1] = pose_keypoints[kpt_id][1]
        # Build the pose once, after all keypoints have been collected.
        pose = Pose(pose_keypoints, pose_entries[0][18])
        current_poses.append(pose)
        for pose in current_poses:
            pose.draw(img)
        img = cv2.addWeighted(orig_img, 0.6, img, 0.4, 0)
        my_array_for_publishing.data = pose_keypoints_ros_data.tolist()
        # cv2.imshow('Lightweight Human Pose Estimation Python Demo', img)
        self.image_pub.publish(self.bridge.cv2_to_imgmsg(img, "bgr8"))
        self.keypts_pub.publish(my_array_for_publishing)
        # cv2.imwrite('/home/zheng/Bureau/image_1400_key.jpg',img)

        cv2.waitKey(2)
import torch
from models.with_mobilenet import PoseEstimationWithMobileNet  # my network architecture
from modules.load_state import load_state
from torch2trt import torch2trt
import time

# HOW TO COMPILE AND SAVE A MODEL
checkpoint_path = '/home/nvidia/Documents/poseFINAL/checkpoints/body.pth'  # path to your trained weights

net = PoseEstimationWithMobileNet()  # network instance
checkpoint = torch.load(checkpoint_path, map_location='cuda')
load_state(net, checkpoint)  # load your trained weights
net.cuda().eval()

data = torch.rand((
    1, 3, 256,
    344)).cuda()  # a random tensor with the shape of your input data

#model_trt = torch2trt(net, [data])  # creates the compiled version of the model; this takes a while

#torch.save(model_trt.state_dict(), 'net_trt.pth')  # saves the compiled weights, which differ from the original ones

# HOW TO LOAD THE MODEL ONCE IT HAS ALREADY BEEN COMPILED (as here)

from torch2trt import TRTModule  # import the module class

model_trt = TRTModule()  # the compiled model instance

model_trt.load_state_dict(torch.load(
    'net_trt.pth'))  # load the compiled weights into the compiled model
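
The snippet imports time but never uses it; the natural follow-up is a latency comparison between the eager model and the compiled TRTModule. A sketch, assuming both models and the input tensor from above are already on the GPU:

def benchmark(model, data, n_warmup=10, n_runs=50):
    # Warm up, then time synchronized forward passes.
    with torch.no_grad():
        for _ in range(n_warmup):
            model(data)
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(n_runs):
            model(data)
        torch.cuda.synchronize()
    return (time.time() - start) / n_runs

print('eager: %.1f ms' % (1000 * benchmark(net, data)))
print('trt:   %.1f ms' % (1000 * benchmark(model_trt, data)))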
Example #19
def train(prepared_train_labels, train_images_folder, num_refinement_stages,
          base_lr, batch_size, batches_per_iter, num_workers, checkpoint_path,
          weights_only, from_mobilenet, checkpoints_folder, log_after,
          val_labels, val_images_folder, val_output_name, checkpoint_after,
          val_after):
    net = PoseEstimationWithMobileNet(num_refinement_stages)

    stride = 8
    sigma = 7
    path_thickness = 1
    dataset = CocoTrainDataset(prepared_train_labels,
                               train_images_folder,
                               stride,
                               sigma,
                               path_thickness,
                               transform=transforms.Compose([
                                   ConvertKeypoints(),
                                   Scale(),
                                   Rotate(pad=(128, 128, 128)),
                                   CropPad(pad=(128, 128, 128)),
                                   Flip()
                               ]))
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)

    optimizer = optim.Adam([
        {
            'params': get_parameters_conv(net.model, 'weight')
        },
        {
            'params': get_parameters_conv_depthwise(net.model, 'weight'),
            'weight_decay': 0
        },
        {
            'params': get_parameters_bn(net.model, 'weight'),
            'weight_decay': 0
        },
        {
            'params': get_parameters_bn(net.model, 'bias'),
            'lr': base_lr * 2,
            'weight_decay': 0
        },
        {
            'params': get_parameters_conv(net.cpm, 'weight'),
            'lr': base_lr
        },
        {
            'params': get_parameters_conv(net.cpm, 'bias'),
            'lr': base_lr * 2,
            'weight_decay': 0
        },
        {
            'params': get_parameters_conv_depthwise(net.cpm, 'weight'),
            'weight_decay': 0
        },
        {
            'params': get_parameters_conv(net.initial_stage, 'weight'),
            'lr': base_lr
        },
        {
            'params': get_parameters_conv(net.initial_stage, 'bias'),
            'lr': base_lr * 2,
            'weight_decay': 0
        },
        {
            'params': get_parameters_conv(net.refinement_stages, 'weight'),
            'lr': base_lr * 4
        },
        {
            'params': get_parameters_conv(net.refinement_stages, 'bias'),
            'lr': base_lr * 8,
            'weight_decay': 0
        },
        {
            'params': get_parameters_bn(net.refinement_stages, 'weight'),
            'weight_decay': 0
        },
        {
            'params': get_parameters_bn(net.refinement_stages, 'bias'),
            'lr': base_lr * 2,
            'weight_decay': 0
        },
    ],
                           lr=base_lr,
                           weight_decay=5e-4)

    num_iter = 0
    current_epoch = 0
    drop_after_epoch = [100, 200, 260]
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=drop_after_epoch,
                                               gamma=0.333)
    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path)

        if from_mobilenet:
            load_from_mobilenet(net, checkpoint)
        else:
            load_state(net, checkpoint)
            if not weights_only:
                optimizer.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                num_iter = checkpoint['iter']
                current_epoch = checkpoint['current_epoch']

    net = DataParallel(net).cuda()
    net.train()
    for epochId in range(current_epoch, 280):
        scheduler.step()
        total_losses = [0, 0] * (num_refinement_stages + 1)  # heatmaps loss, paf loss per stage
        batch_per_iter_idx = 0
        for batch_data in train_loader:
            if batch_per_iter_idx == 0:
                optimizer.zero_grad()

            # print("show imgs"
            #       , batch_data['keypoint_maps'].shape, batch_data['paf_maps'].shape
            #       , batch_data['keypoint_mask'].shape, batch_data['paf_mask'].shape
            #       , batch_data['mask'].shape, batch_data['image'].shape
            #       )
            # print("seg", batch_data['label']['segmentations'])
            print("batched images size", batch_data['image'].shape)

            vis.images(batch_data['image'][:, [2, 1, 0], ...] + 0.5,
                       4,
                       2,
                       "1",
                       opts=dict(title="img"))
            vis.images(batch_data['keypoint_mask'].permute(1, 0, 2, 3),
                       4,
                       2,
                       "2",
                       opts=dict(title="kp_mask"))
            vis.images(batch_data['paf_mask'].permute(1, 0, 2, 3),
                       4,
                       2,
                       "3",
                       opts=dict(title="paf_mask"))
            vis.images(batch_data['keypoint_maps'].permute(1, 0, 2, 3),
                       4,
                       2,
                       "4",
                       opts=dict(title="keypoint_maps"))
            vis.images(batch_data['paf_maps'].permute(1, 0, 2, 3),
                       4,
                       2,
                       "5",
                       opts=dict(title="paf_maps"))
            vis.images(batch_data['mask'].unsqueeze(0),
                       4,
                       2,
                       "6",
                       opts=dict(title="MASK"))

            images = batch_data['image'].cuda()
            keypoint_masks = batch_data['keypoint_mask'].cuda()
            paf_masks = batch_data['paf_mask'].cuda()
            keypoint_maps = batch_data['keypoint_maps'].cuda()
            paf_maps = batch_data['paf_maps'].cuda()

            pafs = batch_data['paf_maps'][0].permute(1, 2, 0).numpy()

            scale = 4
            img_p = np.zeros((pafs.shape[1] * 8, pafs.shape[0] * 8, 3),
                             dtype=np.uint8)
            # pafs[pafs < 0.07] = 0
            for idx in range(len(BODY_PARTS_PAF_IDS)):
                # print(pp, pafs.shape)
                pp = BODY_PARTS_PAF_IDS[idx]
                k_idx = BODY_PARTS_KPT_IDS[idx]
                cc = BODY_CONN_COLOR[idx]

                vx = pafs[:, :, pp[0]]
                vy = pafs[:, :, pp[1]]
                for i in range(pafs.shape[1]):
                    for j in range(pafs.shape[0]):
                        a = (i * 2 * scale, j * 2 * scale)
                        b = (2 * int((i + vx[j, i] * 3) * scale), 2 * int(
                            (j + vy[j, i] * 3) * scale))
                        if a[0] == b[0] and a[1] == b[1]:
                            continue

                        cv2.line(img_p, a, b, cc, 1)

                # break

            cv2.imshow("paf", img_p)
            key = cv2.waitKey(0)
            if key == 27:  # esc
                exit(0)

            stages_output = net(images)

            losses = []
            for loss_idx in range(len(total_losses) // 2):
                losses.append(
                    l2_loss(stages_output[loss_idx * 2], keypoint_maps,
                            keypoint_masks, images.shape[0]))
                losses.append(
                    l2_loss(stages_output[loss_idx * 2 + 1], paf_maps,
                            paf_masks, images.shape[0]))
                total_losses[loss_idx *
                             2] += losses[-2].item() / batches_per_iter
                total_losses[loss_idx * 2 +
                             1] += losses[-1].item() / batches_per_iter

            loss = losses[0]
            for loss_idx in range(1, len(losses)):
                loss += losses[loss_idx]
            loss /= batches_per_iter
            loss.backward()
            batch_per_iter_idx += 1
            if batch_per_iter_idx == batches_per_iter:
                optimizer.step()
                batch_per_iter_idx = 0
                num_iter += 1
            else:
                continue

            if num_iter % log_after == 0:
                print('Iter: {}'.format(num_iter))
                for loss_idx in range(len(total_losses) // 2):
                    print('\n'.join([
                        'stage{}_pafs_loss:     {}',
                        'stage{}_heatmaps_loss: {}'
                    ]).format(loss_idx + 1,
                              total_losses[loss_idx * 2 + 1] / log_after,
                              loss_idx + 1,
                              total_losses[loss_idx * 2] / log_after))
                for loss_idx in range(len(total_losses)):
                    total_losses[loss_idx] = 0
            if num_iter % checkpoint_after == 0:
                snapshot_name = '{}/checkpoint_iter_{}.pth'.format(
                    checkpoints_folder, num_iter)
                torch.save(
                    {
                        'state_dict': net.module.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'iter': num_iter,
                        'current_epoch': epochId
                    }, snapshot_name)
            if num_iter % val_after == 0:
                print('Validation...')
                evaluate(val_labels, val_output_name, val_images_folder, net)
                net.train()
Example #20
if __name__ == '__main__':
    # parser = argparse.ArgumentParser(
    #     description='''Lightweight human pose estimation python demo.
    #                    This is just for quick results preview.
    #                    Please, consider c++ demo for the best performance.''')
    # parser.add_argument('--checkpoint-path', type=str, required=True, help='path to the checkpoint')
    # parser.add_argument('--video', type=str, default='', help='path to video file or camera id')
    # args = parser.parse_args()

    # if args.video == '':
    #     raise ValueError('--video has to be provided')

    # net = PoseEstimationWithMobileNet(num_heatmaps=26, num_pafs=52)
    # checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
    # load_state(net, checkpoint)

    # frame_provider = VideoReader(args.video)

    # run_demo(net, frame_provider, 256, False, True, True)

    net = PoseEstimationWithMobileNet(num_heatmaps=26,
                                      num_pafs=52,
                                      num_refinement_stages=1)
    checkpoint = torch.load('body25_checkpoints/checkpoint_iter_465000.pth',
                            map_location='cpu')
    load_state(net, checkpoint)

    frame_provider = VideoReader('D:/projects/MotioNet/video/beyonce.mp4')

    run_demo(net, frame_provider, 256, False, True, True)
    def __init__(self, model_path):
        self.model = PoseEstimationWithMobileNet()
        checkpoint = torch.load(model_path, map_location='cpu')
        load_state(self.model, checkpoint)
        self.model = self.model.eval()
        self.model = self.model.cuda()
def convert_to_skelets(in_, out_, cpu=False, height_size=256):
    #   height_size - network input layer height size
    #   cpu - True if we would like to run in CPU
    print('start convert to skelets')
    # mask that shows - this is bed
    mask = cv2.imread(os.path.join('mask', 'mask.jpg'), 0)
    mask = cv2.normalize(mask,
                         None,
                         alpha=0,
                         beta=1,
                         norm_type=cv2.NORM_MINMAX,
                         dtype=cv2.CV_32F)
    net = PoseEstimationWithMobileNet()
    # The original snippet never loaded a checkpoint here; the file name used
    # throughout the other examples is assumed.
    checkpoint = torch.load('checkpoint_iter_370000.pth', map_location='cpu')
    load_state(net, checkpoint)
    net = net.eval()
    if not cpu:
        net = net.cuda()

    stride = 8
    upsample_ratio = 4

    max_number = 963
    num_img = 0
    stream = cv2.VideoCapture("rtsp://*****:*****@62.140.233.76:554")
    #    for num in range(0, max_number + 1):
    while True:
        #        frame = 'frame' + str(num) + '.jpg'
        #        img = cv2.imread(os.path.join(in_, frame), cv2.IMREAD_COLOR)

        r, img = stream.read()

        # cv2.destroyAllWindows()
        # find the place of the bed - and add border to it, so we can cut the unnecessary part
        # apply object detection and find bed
        # output is an image with black pixels of not bed, and white pixels of bed

        heatmaps, pafs, scale, pad = infer_fast(net, img, height_size, stride,
                                                upsample_ratio, cpu)

        total_keypoints_num = 0
        all_keypoints_by_type = []
        for kpt_idx in range(18):
            total_keypoints_num += extract_keypoints(heatmaps[:, :, kpt_idx],
                                                     all_keypoints_by_type,
                                                     total_keypoints_num)
        pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type,
                                                      pafs)
        for kpt_id in range(all_keypoints.shape[0]):
            all_keypoints[kpt_id, 0] = (all_keypoints[kpt_id, 0] * stride /
                                        upsample_ratio - pad[1]) / scale
            all_keypoints[kpt_id, 1] = (all_keypoints[kpt_id, 1] * stride /
                                        upsample_ratio - pad[0]) / scale
        # how many persons in image
        num_persons = len(pose_entries)
        # num_img more than time_period - we delete first second and add the last second

        bones_detected = np.zeros(len(bones_to_detect))
        bones_xa = np.zeros(len(bones_to_detect))
        bones_ya = np.zeros(len(bones_to_detect))
        bones_xb = np.zeros(len(bones_to_detect))
        bones_yb = np.zeros(len(bones_to_detect))
        bones_in_bed = np.zeros(len(bones_to_detect))

        for n in range(num_persons):
            count_person_not_in_bed = 1
            for id_x in range(len(bones_to_detect)):
                bones_detected[id_x] = 0
                bones_xa[id_x] = 0
                bones_ya[id_x] = 0
                bones_xb[id_x] = 0
                bones_yb[id_x] = 0
                bones_in_bed[id_x] = 0
            if len(pose_entries[n]) == 0:
                continue
            for id_, part_id in enumerate(bones_to_detect):
                kpt_a_id = BODY_PARTS_KPT_IDS[part_id][0]
                global_kpt_a_id = pose_entries[n][kpt_a_id]
                kpt_b_id = BODY_PARTS_KPT_IDS[part_id][1]
                global_kpt_b_id = pose_entries[n][kpt_b_id]
                # if both points are detected
                if global_kpt_a_id != -1 and global_kpt_b_id != -1:
                    bones_xa[id_], bones_ya[id_] = all_keypoints[
                        int(global_kpt_a_id), 0:2]
                    bones_xb[id_], bones_yb[id_] = all_keypoints[
                        int(global_kpt_b_id), 0:2]
                    if mask[int(bones_ya[id_])][int(
                            bones_xa[id_])] == 1 and mask[int(
                                bones_yb[id_])][int(bones_xb[id_])] == 1:
                        bones_in_bed[id_] = 1
                    bones_detected[id_] = 1

            sum_bones = 0
            for id_, val in enumerate(bones_in_bed):
                sum_bones += val
            if sum_bones == len(bones_in_bed):
                # anomaly
                # we take mean vector of 2 vectors of bones 6 and 9
                bone_xa = (bones_xa[0] + bones_xa[2]) / 2
                bone_ya = (bones_ya[0] + bones_ya[2]) / 2
                bone_xb = (bones_xb[0] + bones_xb[2]) / 2
                bone_yb = (bones_yb[0] + bones_yb[2]) / 2
                x1 = bone_xb - bone_xa
                y1 = bone_yb - bone_ya
                x2 = 100
                y2 = 0
                global anomaly_checker
                alfa = math.acos(
                    (x1 * x2 + y1 * y2) /
                    (math.sqrt(x1**2 + y1**2) * math.sqrt(x2**2 + y2**2)))
                # if alfa is close to 90 degree - anomaly
                if min(abs(alfa - rad_90), abs(alfa - rad_270)) <= threshold:
                    print('num_persons', num_persons)
                    if num_persons == 1:
                        anomaly_checker = np.delete(anomaly_checker, 0)
                        anomaly_checker = np.append(anomaly_checker, 1)
                    # 'frame' was left over from the commented-out file loop;
                    # name the saved image after the running frame counter.
                    cv2.imwrite(os.path.join('out_out', 'frame%d.jpg' % num_img), img)
                if np.sum(anomaly_checker) >= SEC_WITHOUT_HELP:
                    print('ALARM!')

        num_img += 1

        if not os.path.exists(out_):
            os.mkdir(out_)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    print('done convert to skelets')
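
Several module-level names used by convert_to_skelets (bones_to_detect, anomaly_checker, rad_90, rad_270, threshold, SEC_WITHOUT_HELP) are not shown anywhere in this example. Plausible definitions consistent with how they are used; every concrete value here is a guess:

import math
import numpy as np

# Hypothetical globals assumed by convert_to_skelets(); all values are guesses.
bones_to_detect = [6, 8, 9]   # BODY_PARTS_KPT_IDS part ids; entries 0 and 2 ("bones 6 and 9") are averaged above
rad_90 = math.radians(90)
rad_270 = math.radians(270)
threshold = math.radians(15)  # how close to vertical counts as an anomaly
SEC_WITHOUT_HELP = 10         # consecutive anomalous frames before the alarm fires
anomaly_checker = np.zeros(SEC_WITHOUT_HELP)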
def train(prepared_train_labels, train_images_folder, num_refinement_stages,
          base_lr, batch_size, batches_per_iter, num_workers, checkpoint_path,
          weights_only, from_mobilenet, checkpoints_folder, log_after,
          val_labels, val_images_folder, val_output_name, checkpoint_after,
          val_after):
    net = PoseEstimationWithMobileNet(num_refinement_stages)

    stride = 8
    sigma = 7
    path_thickness = 1
    dataset = CocoTrainDataset(prepared_train_labels,
                               train_images_folder,
                               stride,
                               sigma,
                               path_thickness,
                               transform=transforms.Compose([
                                   ConvertKeypoints(),
                                   Scale(),
                                   Rotate(pad=(128, 128, 128)),
                                   CropPad(pad=(128, 128, 128)),
                                   Flip()
                               ]))
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)

    optimizer = optim.Adam([
        {
            'params': get_parameters_conv(net.model, 'weight')
        },
        {
            'params': get_parameters_conv_depthwise(net.model, 'weight'),
            'weight_decay': 0
        },
        {
            'params': get_parameters_bn(net.model, 'weight'),
            'weight_decay': 0
        },
        {
            'params': get_parameters_bn(net.model, 'bias'),
            'lr': base_lr * 2,
            'weight_decay': 0
        },
        {
            'params': get_parameters_conv(net.cpm, 'weight'),
            'lr': base_lr
        },
        {
            'params': get_parameters_conv(net.cpm, 'bias'),
            'lr': base_lr * 2,
            'weight_decay': 0
        },
        {
            'params': get_parameters_conv_depthwise(net.cpm, 'weight'),
            'weight_decay': 0
        },
        {
            'params': get_parameters_conv(net.initial_stage, 'weight'),
            'lr': base_lr
        },
        {
            'params': get_parameters_conv(net.initial_stage, 'bias'),
            'lr': base_lr * 2,
            'weight_decay': 0
        },
        {
            'params': get_parameters_conv(net.refinement_stages, 'weight'),
            'lr': base_lr * 4
        },
        {
            'params': get_parameters_conv(net.refinement_stages, 'bias'),
            'lr': base_lr * 8,
            'weight_decay': 0
        },
        {
            'params': get_parameters_bn(net.refinement_stages, 'weight'),
            'weight_decay': 0
        },
        {
            'params': get_parameters_bn(net.refinement_stages, 'bias'),
            'lr': base_lr * 2,
            'weight_decay': 0
        },
    ],
                           lr=base_lr,
                           weight_decay=5e-4)

    num_iter = 0
    current_epoch = 0
    drop_after_epoch = [100, 200, 260]
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=drop_after_epoch,
                                               gamma=0.333)
    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path)

        if from_mobilenet:
            load_from_mobilenet(net, checkpoint)
        else:
            load_state(net, checkpoint)
            if not weights_only:
                optimizer.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                num_iter = checkpoint['iter']
                current_epoch = checkpoint['current_epoch']

    net = DataParallel(net).cuda()
    net.train()
    for epochId in range(current_epoch, 280):
        scheduler.step()
        total_losses = [0, 0] * (num_refinement_stages + 1)  # heatmaps loss, paf loss per stage
        batch_per_iter_idx = 0
        for batch_data in train_loader:
            if batch_per_iter_idx == 0:
                optimizer.zero_grad()

            images = batch_data['image'].cuda()
            keypoint_masks = batch_data['keypoint_mask'].cuda()
            paf_masks = batch_data['paf_mask'].cuda()
            keypoint_maps = batch_data['keypoint_maps'].cuda()
            paf_maps = batch_data['paf_maps'].cuda()

            stages_output = net(images)

            losses = []
            for loss_idx in range(len(total_losses) // 2):
                losses.append(
                    l2_loss(stages_output[loss_idx * 2], keypoint_maps,
                            keypoint_masks, images.shape[0]))
                losses.append(
                    l2_loss(stages_output[loss_idx * 2 + 1], paf_maps,
                            paf_masks, images.shape[0]))
                total_losses[loss_idx *
                             2] += losses[-2].item() / batches_per_iter
                total_losses[loss_idx * 2 +
                             1] += losses[-1].item() / batches_per_iter

            loss = losses[0]
            for loss_idx in range(1, len(losses)):
                loss += losses[loss_idx]
            loss /= batches_per_iter
            loss.backward()
            batch_per_iter_idx += 1
            if batch_per_iter_idx == batches_per_iter:
                optimizer.step()
                batch_per_iter_idx = 0
                num_iter += 1
            else:
                continue

            if num_iter % log_after == 0:
                print('Iter: {}'.format(num_iter))
                for loss_idx in range(len(total_losses) // 2):
                    print('\n'.join([
                        'stage{}_pafs_loss:     {}',
                        'stage{}_heatmaps_loss: {}'
                    ]).format(loss_idx + 1,
                              total_losses[loss_idx * 2 + 1] / log_after,
                              loss_idx + 1,
                              total_losses[loss_idx * 2] / log_after))
                for loss_idx in range(len(total_losses)):
                    total_losses[loss_idx] = 0
            if num_iter % checkpoint_after == 0:
                snapshot_name = '{}/checkpoint_iter_{}.pth'.format(
                    checkpoints_folder, num_iter)
                torch.save(
                    {
                        'state_dict': net.module.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'iter': num_iter,
                        'current_epoch': epochId
                    }, snapshot_name)
            if num_iter % val_after == 0:
                print('Validation...')
                evaluate(val_labels, val_output_name, val_images_folder, net)
                net.train()
Example #24
def gen():
    """Video streaming generator function."""
    HOST = ''
    PORT = 8088

    emptyPoses = []

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    print('Socket created')

    s.bind((HOST, PORT))
    print('Socket bind complete')
    s.listen(10)
    print('Socket now listening')

    net = PoseEstimationWithMobileNet()
    checkpoint = torch.load('checkpoint_iter_370000.pth', map_location='cpu')
    load_state(net, checkpoint)

    conn, addr = s.accept()
    print('ACCEPTED')

    data = b''  ### CHANGED
    payload_size = struct.calcsize("=L")  ### CHANGED

    stepA = False
    stepB = False
    global count
    count = 0
    sitAngle = 0
    stdupAngle = 0

    font = cv2.FONT_HERSHEY_SIMPLEX
    bottomLeftCornerOfText = (50, 400)
    topLeft = (150, 400)

    fontScale = 3
    fontColor = (255, 0, 0)
    lineType = 2

    emptyPoses = []

    while True:
        while len(data) < payload_size:
            data += conn.recv(4096)
        print('MESSAGESIZE')
        print(payload_size)
        packed_msg_size = data[:payload_size]
        data = data[payload_size:]
        msg_size = struct.unpack("=L", packed_msg_size)[0]  ### CHANGED
        print('unpack')
        # Retrieve all data based on message size
        while len(data) < msg_size:
            data += conn.recv(4096)
            print(len(data))
            print(msg_size)
        print('RECEIVED')

        frame_data = data[:msg_size]
        data = data[msg_size:]

        # Extract frame
        frame = pickle.loads(frame_data)

        #read_return_code, frame = vc.read()

        pose = run_demo(net, frame, 256, 0, 0, 1)

        if pose is not None:
            pose.draw(frame)
            #cv2.imshow('test', frame)
            if cv2.waitKey(1) == ord('q'):
                break

            A = np.array([pose.keypoints[8][0], pose.keypoints[8][1]])
            B = np.array([pose.keypoints[9][0], pose.keypoints[9][1]])
            C = np.array([pose.keypoints[10][0], pose.keypoints[10][1]])

            BA = A - B
            BC = C - B

            cosine_angle = np.dot(
                BA, BC) / (np.linalg.norm(BA) * np.linalg.norm(BC))
            angle = np.arccos(cosine_angle)
            angle = np.degrees(angle)
            #print(angle)

            if angle > 140:
                stepA = True
                if stdupAngle == 0:
                    stdupAngle = angle
                if angle > stdupAngle:
                    stdupAngle = angle
            if angle < 70:
                stepB = True
                if sitAngle == 0:
                    sitAngle = angle
                if angle < sitAngle:
                    sitAngle = angle

            if stepA and stepB:
                if angle > 140:
                    stepA = False
                    stepB = False
                    count += 1

                    if sitAngle > 60:
                        cv2.putText(frame, "Bend your knee more", topLeft,
                                    font, fontScale, fontColor, lineType)

                    stdupDiff = 140 - stdupAngle
                    sitDiff = 70 - sitAngle

                    stdupDiff = abs(stdupDiff)
                    sitDiff = abs(sitDiff)
                    correctness = 280 - (stdupDiff + sitDiff)

                    cv2.putText(frame, "correctness" + str(correctness),
                                bottomLeftCornerOfText, font, fontScale,
                                fontColor, lineType)

                    sitAngle = 0
                    stdupAngle = 0
            if stepA and not stepB:
                if angle > 140:
                    stepA = False
                    stepB = False
                    sitAngle = 0
                    stdupAngle = 0

            cv2.putText(frame, "count" + str(count), bottomLeftCornerOfText,
                        font, fontScale, fontColor, lineType)
            #cv2.imshow("img", frame)

            print(count)
            encode_return_code, image_buffer = cv2.imencode('.jpg', frame)
            io_buf = io.BytesIO(image_buffer)
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + io_buf.read() +
                   b'\r\n')
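
The multipart JPEG frames yielded by gen() are meant for an MJPEG streaming endpoint; the standard Flask wiring looks like this (a sketch; the route name is an assumption):

from flask import Flask, Response

app = Flask(__name__)

@app.route('/video_feed')
def video_feed():
    # Each chunk from gen() is one JPEG frame in a multipart stream.
    return Response(gen(),
                    mimetype='multipart/x-mixed-replace; boundary=frame')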
Example #25
def main():
    # The head of this snippet was truncated; the parser setup below is
    # reconstructed to mirror the identical demo above, so the earlier
    # arguments are assumptions.
    parser = argparse.ArgumentParser(
        description='Lightweight human pose estimation python demo.')
    parser.add_argument('--checkpoint-path',
                        type=str,
                        required=True,
                        help='path to the checkpoint')
    parser.add_argument('--height-size',
                        type=int,
                        default=256,
                        help='network input layer height size')
    parser.add_argument('--video',
                        type=str,
                        default='',
                        help='path to video file or camera id')
    parser.add_argument('--images',
                        nargs='+',
                        default='data/2.jpg',
                        help='path to input image(s)')
    parser.add_argument('--cpu',
                        action='store_true',
                        help='run network inference on cpu')
    parser.add_argument('--track',
                        type=int,
                        default=1,
                        help='track pose id in video')
    parser.add_argument('--smooth',
                        type=int,
                        default=1,
                        help='smooth pose keypoints')
    args = parser.parse_args()

    if args.video == '' and args.images == '':
        raise ValueError('Either --video or --images has to be provided')

    net = PoseEstimationWithMobileNet()
    checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
    load_state(net, checkpoint)

    frame_provider = ImageReader(args.images)
    if args.video != '':
        frame_provider = VideoReader(args.video)
    else:
        args.track = 0

    run_demo(net, frame_provider, args.height_size, args.cpu, args.track,
             args.smooth)
def train(prepared_train_labels, train_images_folder, num_refinement_stages,
          base_lr, batch_size, batches_per_iter, num_workers, checkpoint_path,
          weights_only, from_mobilenet, checkpoints_folder, log_after,
          val_labels, val_images_folder, val_output_name, checkpoint_after,
          val_after):
    net = PoseEstimationWithMobileNet(num_refinement_stages)

    stride = 8
    sigma = 7
    path_thickness = 1
    dataset = CocoTrainDataset(prepared_train_labels,
                               train_images_folder,
                               stride,
                               sigma,
                               path_thickness,
                               transform=transforms.Compose([
                                   ConvertKeypoints(),
                                   Scale(),
                                   Rotate(pad=(128, 128, 128)),
                                   CropPad(pad=(128, 128, 128)),
                                   Flip()
                               ]))
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)

    optimizer = optim.Adam([
        {
            'params': get_parameters_conv(net.model, 'weight')
        },
        {
            'params': get_parameters_conv_depthwise(net.model, 'weight'),
            'weight_decay': 0
        },
        {
            'params': get_parameters_bn(net.model, 'weight'),
            'weight_decay': 0
        },
        {
            'params': get_parameters_bn(net.model, 'bias'),
            'lr': base_lr * 2,
            'weight_decay': 0
        },
        {
            'params': get_parameters_conv(net.cpm, 'weight'),
            'lr': base_lr
        },
        {
            'params': get_parameters_conv(net.cpm, 'bias'),
            'lr': base_lr * 2,
            'weight_decay': 0
        },
        {
            'params': get_parameters_conv_depthwise(net.cpm, 'weight'),
            'weight_decay': 0
        },
        {
            'params': get_parameters_conv(net.initial_stage, 'weight'),
            'lr': base_lr
        },
        {
            'params': get_parameters_conv(net.initial_stage, 'bias'),
            'lr': base_lr * 2,
            'weight_decay': 0
        },
        {
            'params': get_parameters_conv(net.refinement_stages, 'weight'),
            'lr': base_lr * 4
        },
        {
            'params': get_parameters_conv(net.refinement_stages, 'bias'),
            'lr': base_lr * 8,
            'weight_decay': 0
        },
        {
            'params': get_parameters_bn(net.refinement_stages, 'weight'),
            'weight_decay': 0
        },
        {
            'params': get_parameters_bn(net.refinement_stages, 'bias'),
            'lr': base_lr * 2,
            'weight_decay': 0
        },
    ],
                           lr=base_lr,
                           weight_decay=5e-4)

    num_iter = 0
    current_epoch = 0
    drop_after_epoch = [100, 200, 260]
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=drop_after_epoch,
                                               gamma=0.333)
    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path)

        if from_mobilenet:
            load_from_mobilenet(net, checkpoint)
        else:
            load_state(net, checkpoint)
            if not weights_only:
                optimizer.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                num_iter = checkpoint['iter']
                current_epoch = checkpoint['current_epoch']
                print("optimizer LR")
                for param_group in optimizer.param_groups:
                    print(param_group['lr'])

                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.cuda()

    net = DataParallel(net).cuda()
    net.train()

    from DGPT.Visualize.Viz import Viz
    viz = Viz(dict(env="refine"))

    for epochId in range(current_epoch, 280):
        # scheduler.step()
        total_losses = [0, 0] * (num_refinement_stages + 1)  # heatmaps loss, paf loss per stage
        batch_per_iter_idx = 0
        for batch_data in train_loader:
            if batch_per_iter_idx == 0:
                optimizer.zero_grad()

            images = batch_data['image'].cuda()
            keypoint_masks = batch_data['keypoint_mask'].cuda()
            paf_masks = batch_data['paf_mask'].cuda()
            keypoint_maps = batch_data['keypoint_maps'].cuda()
            paf_maps = batch_data['paf_maps'].cuda()

            images = preprocess(images)

            stages_output = net(images)

            losses = []
            for loss_idx in range(len(total_losses) // 2):
                losses.append(
                    l2_loss(stages_output[loss_idx * 2], keypoint_maps,
                            keypoint_masks, images.shape[0]))
                losses.append(
                    l2_loss(stages_output[loss_idx * 2 + 1], paf_maps,
                            paf_masks, images.shape[0]))
                total_losses[loss_idx *
                             2] += losses[-2].item() / batches_per_iter
                total_losses[loss_idx * 2 +
                             1] += losses[-1].item() / batches_per_iter

            loss = losses[0]
            for loss_idx in range(1, len(losses)):
                loss += losses[loss_idx]
            loss /= batches_per_iter
            loss.backward()

            viz.draw_line(num_iter, loss.item(), "Loss")

            batch_per_iter_idx += 1
            if batch_per_iter_idx == batches_per_iter:
                optimizer.step()
                batch_per_iter_idx = 0
                num_iter += 1
                scheduler.step()
            else:
                continue

            if num_iter % log_after == 0:
                print('Iter: {}'.format(num_iter))
                for loss_idx in range(len(total_losses) // 2):
                    print('\n'.join([
                        'stage{}_pafs_loss:     {}',
                        'stage{}_heatmaps_loss: {}'
                    ]).format(loss_idx + 1,
                              total_losses[loss_idx * 2 + 1] / log_after,
                              loss_idx + 1,
                              total_losses[loss_idx * 2] / log_after))
                for loss_idx in range(len(total_losses)):
                    total_losses[loss_idx] = 0

                xx = images[:1, ...].detach()  #.clone()
                hh = keypoint_maps[:1, ...].detach()  #.clone()
                mm = keypoint_masks[:1, ...].detach()  #.clone()

                print(xx.shape, hh.shape, mm.shape)

                hh = hh.squeeze(0).reshape(19, 1, hh.shape[2], hh.shape[3])
                mm = mm.squeeze(0).reshape(19, 1, hh.shape[2], hh.shape[3])

                viz.draw_images(xx, "input1")
                viz.draw_images(hh, "input1_heatmap")
                viz.draw_images(mm, "input1_mask")

                oh = stages_output[-2].detach()[:1, :-1, ...]
                oh = oh.reshape(oh.shape[1], 1, oh.shape[2], oh.shape[3])
                viz.draw_images(oh, "output1_heatmap")

            if num_iter % checkpoint_after == 0:
                snapshot_name = '{}/checkpoint_iter_{}.pth'.format(
                    checkpoints_folder, num_iter)
                torch.save(
                    {
                        'state_dict': net.module.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'iter': num_iter,
                        'current_epoch': epochId
                    }, snapshot_name)
            if num_iter % val_after == 0:
                print('Validation...')
                evaluate(val_labels, val_output_name, val_images_folder, net)
                net.train()