Example #1
def main(argv):
    tf_device = '/cpu:0'
    with tf.device(tf_device):
        """Build graph
        """
        if FLAGS.color_channel == 'RGB':
            input_data = tf.placeholder(dtype=tf.float32,
                                        shape=[None, FLAGS.input_size, FLAGS.input_size, 3],
                                        name='input_image')
        else:
            input_data = tf.placeholder(dtype=tf.float32,
                                        shape=[None, FLAGS.input_size, FLAGS.input_size, 1],
                                        name='input_image')

        center_map = tf.placeholder(dtype=tf.float32,
                                    shape=[None, FLAGS.input_size, FLAGS.input_size, 1],
                                    name='center_map')

        model = cpm_body_slim.CPM_Model(FLAGS.stages, FLAGS.joints + 1)
        model.build_model(input_data, center_map, 1)

    saver = tf.train.Saver()

    """Create session and restore weights
    """
    sess = tf.Session()

    sess.run(tf.global_variables_initializer())
    if FLAGS.model_path.endswith('pkl'):
        model.load_weights_from_file(FLAGS.model_path, sess, False)
    else:
        saver.restore(sess, FLAGS.model_path)

    # Test-time center map: a Gaussian response centered in the input square
    test_center_map = cpm_utils.gaussian_img(FLAGS.input_size,
                                             FLAGS.input_size,
                                             FLAGS.input_size / 2,
                                             FLAGS.input_size / 2,
                                             FLAGS.cmap_radius)
    test_center_map = np.reshape(test_center_map, [1, FLAGS.input_size,
                                                   FLAGS.input_size, 1])

    # Check weights
    for variable in tf.trainable_variables():
        with tf.variable_scope('', reuse=True):
            var = tf.get_variable(variable.name.split(':0')[0])
            print(variable.name, np.mean(sess.run(var)))

    # Create kalman filters
    if FLAGS.KALMAN_ON:
        kalman_filter_array = [cv2.KalmanFilter(4, 2) for _ in range(FLAGS.joints)]
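        # Constant-velocity model: the state is (x, y, dx, dy), each step advances
        # position by velocity, and only (x, y) is observed via measurementMatrix.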
        for joint_kalman_filter in kalman_filter_array:
            joint_kalman_filter.transitionMatrix = np.array([[1, 0, 1, 0],
                                                             [0, 1, 0, 1],
                                                             [0, 0, 1, 0],
                                                             [0, 0, 0, 1]],
                                                            np.float32)
            joint_kalman_filter.measurementMatrix = np.array([[1, 0, 0, 0],
                                                              [0, 1, 0, 0]],
                                                             np.float32)
            joint_kalman_filter.processNoiseCov = np.array([[1, 0, 0, 0],
                                                            [0, 1, 0, 0],
                                                            [0, 0, 1, 0],
                                                            [0, 0, 0, 1]],
                                                           np.float32) * FLAGS.kalman_noise
    else:
        kalman_filter_array = None

    # read in video / flow frames
    if FLAGS.DEMO_TYPE.endswith(('avi', 'flv', 'mp4')):
        # use imageio here: OpenCV's reader only handles '.avi' files
        cam = imageio.get_reader(FLAGS.DEMO_TYPE)
    else:
        cam = cv2.VideoCapture(FLAGS.cam_num)

    # image processing
    with tf.device(tf_device):
        if FLAGS.DEMO_TYPE.endswith(('avi', 'flv', 'mp4')):
            ori_fps = cam.get_meta_data()['fps']
            print('This video fps is %f' % ori_fps)
            video_length = cam.get_length()
            writer_path = os.path.join('results', os.path.basename(FLAGS.DEMO_TYPE))
            # !! OpenCV can only write in .avi
            cv_writer = cv2.VideoWriter(writer_path + '.avi',
                                        cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'),
                                        ori_fps,
                                        (FLAGS.input_size, FLAGS.input_size))
            # imageio_writer = imageio.get_writer(writer_path, fps=ori_fps)

            try:
                for it, im in enumerate(cam):
                    test_img_t = time.time()

                    test_img = cpm_utils.read_image(im, [], FLAGS.input_size, 'VIDEO')
                    test_img_resize = cv2.resize(test_img, (FLAGS.input_size, FLAGS.input_size))
                    print('img read time %f' % (time.time() - test_img_t))

                    if FLAGS.color_channel == 'GRAY':
                        test_img_resize = mgray(test_img_resize, test_img)

                    # Scale pixels from [0, 255] into [-0.5, 0.5) and add a batch dimension
                    test_img_input = test_img_resize / 256.0 - 0.5
                    test_img_input = np.expand_dims(test_img_input, axis=0)

                    # Inference
                    fps_t = time.time()
                    predict_heatmap, stage_heatmap_np = sess.run([model.current_heatmap,
                                                                  model.stage_heatmap,
                                                                  ],
                                                                 feed_dict={'input_image:0': test_img_input,
                                                                            'center_map:0': test_center_map})

                    # Show visualized image
                    demo_img = visualize_result(test_img, FLAGS, stage_heatmap_np, kalman_filter_array)
                    cv2.imshow('demo_img', demo_img.astype(np.uint8))
                    if cv2.waitKey(1) == ord('q'): break
                    print('fps: %.2f' % (1 / (time.time() - fps_t)))

                    cv_writer.write(demo_img.astype(np.uint8))
                    # imageio_writer.append_data(demo_img[:, :, 1])
            except KeyboardInterrupt:
                print('Stopped! {}/{} frames captured!'.format(it, video_length))
            finally:
                cv_writer.release()
                # imageio_writer.close()
        else:
            while True:
                test_img_t = time.time()

                if FLAGS.DEMO_TYPE.endswith(('png', 'jpg')):
                    test_img = cpm_utils.read_image(FLAGS.DEMO_TYPE, [], FLAGS.input_size, 'IMAGE')
                else:
                    test_img = cpm_utils.read_image([], cam, FLAGS.input_size, 'WEBCAM')

                test_img_resize = cv2.resize(test_img, (FLAGS.input_size, FLAGS.input_size))
                print('img read time %f' % (time.time() - test_img_t))

                if FLAGS.color_channel == 'GRAY':
                    test_img_resize = mgray(test_img_resize, test_img)

                test_img_input = test_img_resize / 256.0 - 0.5
                test_img_input = np.expand_dims(test_img_input, axis=0)

                if FLAGS.DEMO_TYPE.endswith(('png', 'jpg')):
                    # Inference
                    fps_t = time.time()
                    predict_heatmap, stage_heatmap_np = sess.run([model.current_heatmap,
                                                                  model.stage_heatmap, ],
                                                                 feed_dict={'input_image:0': test_img_input,
                                                                            'center_map:0': test_center_map})

                    # Show visualized image
                    demo_img = visualize_result(test_img, FLAGS, stage_heatmap_np, kalman_filter_array)
                    cv2.imshow('demo_img', demo_img.astype(np.uint8))
                    if cv2.waitKey(0) == ord('q'): break
                    print('fps: %.2f' % (1 / (time.time() - fps_t)))

                elif FLAGS.DEMO_TYPE == 'MULTI':

                    # Inference
                    fps_t = time.time()
                    predict_heatmap, stage_heatmap_np = sess.run([model.current_heatmap,
                                                                  model.stage_heatmap,
                                                                  ],
                                                                 feed_dict={'input_image:0': test_img_input,
                                                                            'center_map:0': test_center_map})

                    # Show visualized image
                    demo_img = visualize_result(test_img, FLAGS, stage_heatmap_np, kalman_filter_array)
                    cv2.imshow('demo_img', demo_img.astype(np.uint8))
                    if cv2.waitKey(1) == ord('q'): break
                    print('fps: %.2f' % (1 / (time.time() - fps_t)))


                elif FLAGS.DEMO_TYPE == 'SINGLE':

                    # Inference
                    fps_t = time.time()
                    # last stage (the original hardcoded index 5 assumed a six-stage model)
                    stage_heatmap_np = sess.run([model.stage_heatmap[FLAGS.stages - 1]],
                                                feed_dict={'input_image:0': test_img_input,
                                                           'center_map:0': test_center_map})

                    # Show visualized image
                    demo_img = visualize_result(test_img, FLAGS, stage_heatmap_np, kalman_filter_array)
                    cv2.imshow('current heatmap', (demo_img).astype(np.uint8))
                    if cv2.waitKey(1) == ord('q'): break
                    print('fps: %.2f' % (1 / (time.time() - fps_t)))


                elif FLAGS.DEMO_TYPE == 'HM':

                    # Inference
                    fps_t = time.time()
                    stage_heatmap_np = sess.run([model.stage_heatmap[FLAGS.stages - 1]],
                                                feed_dict={'input_image:0': test_img_input,
                                                           'center_map:0': test_center_map})
                    print('fps: %.2f' % (1 / (time.time() - fps_t)))

                    demo_stage_heatmap = stage_heatmap_np[-1][0, :, :, 0:FLAGS.joints].reshape(
                        (FLAGS.hmap_size, FLAGS.hmap_size, FLAGS.joints))
                    demo_stage_heatmap = cv2.resize(demo_stage_heatmap, (FLAGS.input_size, FLAGS.input_size))

                    vertical_imgs = []
                    tmp_img = None
                    joint_coord_set = np.zeros((FLAGS.joints, 2))

                    for joint_num in range(FLAGS.joints):
                        # Concat until 4 img
                        if (joint_num % 4) == 0 and joint_num != 0:
                            vertical_imgs.append(tmp_img)
                            tmp_img = None

                        demo_stage_heatmap[:, :, joint_num] *= (255 / np.max(demo_stage_heatmap[:, :, joint_num]))

                        # Plot color joints
                        if np.min(demo_stage_heatmap[:, :, joint_num]) > -50:
                            joint_coord = np.unravel_index(np.argmax(demo_stage_heatmap[:, :, joint_num]),
                                                           (FLAGS.input_size, FLAGS.input_size))
                            joint_coord_set[joint_num, :] = joint_coord
                            color_code_num = (joint_num // 4)

                            # Fingertip joints (0, 4, 8, 12, 16) and the remaining joints used
                            # identical branches in the original; they are collapsed into one here.
                            if PYTHON_VERSION == 3:
                                joint_color = list(
                                    map(lambda x: x + 35 * (joint_num % 4), joint_color_code[color_code_num]))
                            else:
                                joint_color = map(lambda x: x + 35 * (joint_num % 4), joint_color_code[color_code_num])

                            cv2.circle(test_img, center=(joint_coord[1], joint_coord[0]), radius=3, color=joint_color,
                                       thickness=-1)

                        # Put text
                        tmp = demo_stage_heatmap[:, :, joint_num].astype(np.uint8)
                        tmp = cv2.putText(tmp, 'Min:' + str(np.min(demo_stage_heatmap[:, :, joint_num])),
                                          org=(5, 20), fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.3, color=150)
                        tmp = cv2.putText(tmp, 'Mean:' + str(np.mean(demo_stage_heatmap[:, :, joint_num])),
                                          org=(5, 30), fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=0.3, color=150)
                        tmp_img = np.concatenate((tmp_img, tmp), axis=0) \
                            if tmp_img is not None else tmp

                    # Plot limbs
                    for limb_num in range(len(limbs)):
                        if np.min(demo_stage_heatmap[:, :, limbs[limb_num][0]]) > -2000 and np.min(
                                demo_stage_heatmap[:, :, limbs[limb_num][1]]) > -2000:
                            x1 = joint_coord_set[limbs[limb_num][0], 0]
                            y1 = joint_coord_set[limbs[limb_num][0], 1]
                            x2 = joint_coord_set[limbs[limb_num][1], 0]
                            y2 = joint_coord_set[limbs[limb_num][1], 1]
                            length = ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5
                            if 5 < length < 10000:
                                deg = math.degrees(math.atan2(x1 - x2, y1 - y2))
                                polygon = cv2.ellipse2Poly((int((y1 + y2) / 2), int((x1 + x2) / 2)),
                                                           (int(length / 2), 3),
                                                           int(deg),
                                                           0, 360, 1)
                                color_code_num = limb_num // 4
                                if PYTHON_VERSION == 3:
                                    limb_color = list(
                                        map(lambda x: x + 35 * (limb_num % 4), joint_color_code[color_code_num]))
                                else:
                                    limb_color = map(lambda x: x + 35 * (limb_num % 4), joint_color_code[color_code_num])

                                cv2.fillConvexPoly(test_img, polygon, color=limb_color)

                    if tmp_img is not None:
                        tmp_img = np.lib.pad(tmp_img, ((0, vertical_imgs[0].shape[0] - tmp_img.shape[0]), (0, 0)),
                                             'constant', constant_values=(0, 0))
                        vertical_imgs.append(tmp_img)

                    # Concat horizontally
                    output_img = None
                    for col in range(len(vertical_imgs)):
                        output_img = np.concatenate((output_img, vertical_imgs[col]), axis=1) if output_img is not None else \
                            vertical_imgs[col]

                    output_img = output_img.astype(np.uint8)
                    output_img = cv2.applyColorMap(output_img, cv2.COLORMAP_JET)
                    test_img = cv2.resize(test_img, (300, 300), interpolation=cv2.INTER_LANCZOS4)
                    cv2.imshow('hm', output_img)
                    cv2.moveWindow('hm', 2000, 200)
                    cv2.imshow('rgb', test_img)
                    cv2.moveWindow('rgb', 2000, 750)
                    if cv2.waitKey(1) == ord('q'): break
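
The Kalman smoothing itself happens inside visualize_result, which this page does not show. As a rough, assumption-labeled sketch (smooth_joints is a hypothetical helper, not the repo's API), each per-joint cv2.KalmanFilter(4, 2) built above would typically be driven once per frame: predict, then correct with the argmax of that joint's heatmap.

import numpy as np

def smooth_joints(stage_heatmap, kalman_filter_array, hmap_size, num_joints):
    # stage_heatmap: final-stage output of shape (1, hmap_size, hmap_size, num_joints + 1)
    coords = np.zeros((num_joints, 2), dtype=np.float32)
    for j, kf in enumerate(kalman_filter_array):
        hm = stage_heatmap[0, :, :, j]
        row, col = np.unravel_index(np.argmax(hm), (hmap_size, hmap_size))
        kf.predict()                                    # propagate the constant-velocity state
        state = kf.correct(np.array([[row], [col]], dtype=np.float32))
        coords[j] = state[0, 0], state[1, 0]            # smoothed (row, col)
    return coords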
Example #2
def main(argv):
    tf_device = '/gpu:0'
    with tf.device(tf_device):
        """Build graph
        """
        if FLAGS.color_channel == 'RGB':
            input_data = tf.placeholder(
                dtype=tf.float32,
                shape=[None, FLAGS.input_size, FLAGS.input_size, 3],
                name='input_image')
        else:
            input_data = tf.placeholder(
                dtype=tf.float32,
                shape=[None, FLAGS.input_size, FLAGS.input_size, 1],
                name='input_image')

        center_map = tf.placeholder(
            dtype=tf.float32,
            shape=[None, FLAGS.input_size, FLAGS.input_size, 1],
            name='center_map')

        model = cpm_body_slim.CPM_Model(FLAGS.stages, FLAGS.joints + 1)
        model.build_model(input_data, center_map, 1)

    saver = tf.train.Saver()
    """Create session and restore weights
    """
    sess = tf.Session()

    sess.run(tf.global_variables_initializer())
    if FLAGS.model_path.endswith('pkl'):
        model.load_weights_from_file(FLAGS.model_path, sess, False)
    else:
        saver.restore(sess, FLAGS.model_path)

    test_center_map = cpm_utils.gaussian_img(FLAGS.input_size,
                                             FLAGS.input_size,
                                             FLAGS.input_size / 2,
                                             FLAGS.input_size / 2,
                                             FLAGS.cmap_radius)
    test_center_map = np.reshape(test_center_map,
                                 [1, FLAGS.input_size, FLAGS.input_size, 1])

    # Check weights
    for variable in tf.trainable_variables():
        with tf.variable_scope('', reuse=True):
            var = tf.get_variable(variable.name.split(':0')[0])
            print(variable.name, np.mean(sess.run(var)))

    # Create kalman filters
    if FLAGS.KALMAN_ON:
        kalman_filter_array = [
            cv2.KalmanFilter(4, 2) for _ in range(FLAGS.joints)
        ]
        for joint_kalman_filter in kalman_filter_array:
            joint_kalman_filter.transitionMatrix = np.array(
                [[1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1]],
                np.float32)
            joint_kalman_filter.measurementMatrix = np.array(
                [[1, 0, 0, 0], [0, 1, 0, 0]], np.float32)
            joint_kalman_filter.processNoiseCov = np.array(
                [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
                np.float32) * FLAGS.kalman_noise
    else:
        kalman_filter_array = None

    # read in video / flow frames
    if FLAGS.DEMO_TYPE.endswith(('avi', 'flv', 'mp4')):
        # use imageio here: OpenCV's reader only handles '.avi' files
        cam = imageio.get_reader(FLAGS.DEMO_TYPE)
    else:
        cam = cv2.VideoCapture(FLAGS.cam_num)

    # image processing
    with tf.device(tf_device):
        test_img_t = time.time()

        if FLAGS.DEMO_TYPE.endswith(('png', 'jpg')):
            test_img = cpm_utils.read_image(FLAGS.DEMO_TYPE, [],
                                            FLAGS.input_size, 'IMAGE')
        else:
            test_img = cpm_utils.read_image([], cam, FLAGS.input_size,
                                            'WEBCAM')

        test_img_resize = cv2.resize(test_img,
                                     (FLAGS.input_size, FLAGS.input_size))
        print('img read time %f' % (time.time() - test_img_t))

        if FLAGS.color_channel == 'GRAY':
            test_img_resize = mgray(test_img_resize, test_img)

        test_img_input = test_img_resize / 256.0 - 0.5
        test_img_input = np.expand_dims(test_img_input, axis=0)

        if FLAGS.DEMO_TYPE.endswith(('png', 'jpg')):
            # Inference
            fps_t = time.time()
            predict_heatmap, stage_heatmap_np = sess.run(
                [model.current_heatmap, model.stage_heatmap],
                feed_dict={'input_image:0': test_img_input,
                           'center_map:0': test_center_map})

            # Show visualized image
            demo_img = visualize_result(test_img, FLAGS, stage_heatmap_np,
                                        kalman_filter_array)
            cv2.imshow('demo_img', demo_img.astype(np.uint8))
            if cv2.waitKey(0) == ord('q'): exit()
            print('fps: %.2f' % (1 / (time.time() - fps_t)))
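
cpm_utils.gaussian_img is not shown on this page. As an illustrative stand-in (the exact normalization inside cpm_utils may differ, and 368 / 21 below are merely typical CPM defaults, not values from this repo), an equivalent center map can be built directly in NumPy:

import numpy as np

def gaussian_center_map(size, center_y, center_x, radius):
    # 2-D Gaussian peaked at (center_y, center_x); 'radius' plays the role of sigma
    x = np.arange(size, dtype=np.float32)
    y = x[:, np.newaxis]
    return np.exp(-((x - center_x) ** 2 + (y - center_y) ** 2) / (2.0 * radius ** 2))

# Reshape to the (batch, height, width, channels) layout the center_map placeholder expects
test_center_map = gaussian_center_map(368, 184, 184, 21).reshape(1, 368, 368, 1)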
Example #3
File: train.py Project: sjkai/CPM
def main(argv):
    """Build graph
    """
    batch_x, batch_c, batch_y, batch_x_orig = tf_utils.read_batch_cpm(
        FLAGS.tfr_data_files, FLAGS.input_size, FLAGS.heatmap_size,
        FLAGS.num_of_joints, FLAGS.center_radius, FLAGS.batch_size)
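    # Inferred from usage below: batch_x holds input images, batch_c center maps,
    # batch_y joint label maps; batch_x_orig is unused in this function.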
    if FLAGS.color_channel == 'RGB':
        input_placeholder = tf.placeholder(dtype=tf.float32,
                                           shape=(FLAGS.batch_size,
                                                  FLAGS.input_size,
                                                  FLAGS.input_size, 3),
                                           name='input_placeholder')
    elif FLAGS.color_channel == 'GRAY':
        input_placeholder = tf.placeholder(dtype=tf.float32,
                                           shape=(FLAGS.batch_size,
                                                  FLAGS.input_size,
                                                  FLAGS.input_size, 1),
                                           name='input_placeholder')
    cmap_placeholder = tf.placeholder(dtype=tf.float32,
                                      shape=(FLAGS.batch_size,
                                             FLAGS.input_size,
                                             FLAGS.input_size, 1),
                                      name='cmap_placeholder')
    hmap_placeholder = tf.placeholder(
        dtype=tf.float32,
        shape=(FLAGS.batch_size, FLAGS.heatmap_size, FLAGS.heatmap_size,
               FLAGS.num_of_joints + 1),
        name='hmap_placeholder')

    model = cpm_body_slim.CPM_Model(FLAGS.stages, FLAGS.num_of_joints + 1)
    model.build_model(input_placeholder, cmap_placeholder, FLAGS.batch_size)
    model.build_loss(hmap_placeholder, FLAGS.lr, FLAGS.lr_decay_rate,
                     FLAGS.lr_decay_step)
    print('=====Model Built=====\n')
    """Training
    """
    with tf.Session() as sess:

        # Create dataset queue
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        # Create summary writer
        tf_writer = tf.summary.FileWriter(FLAGS.log_dir,
                                          sess.graph,
                                          filename_suffix=FLAGS.log_file_name)

        # Create model saver
        saver = tf.train.Saver(max_to_keep=None)

        # Init
        init = tf.global_variables_initializer()
        sess.run(init)

        # Restore weights
        if FLAGS.pretrained_model is not None:
            if FLAGS.pretrained_model.endswith('.pkl'):
                model.load_weights_from_file(FLAGS.pretrained_model,
                                             sess,
                                             finetune=True)

                # Check weights
                for variable in tf.trainable_variables():
                    with tf.variable_scope('', reuse=True):
                        var = tf.get_variable(variable.name.split(':0')[0])
                        print(variable.name, np.mean(sess.run(var)))

            else:
                saver.restore(sess, FLAGS.pretrained_model)

                # check weights
                for variable in tf.trainable_variables():
                    with tf.variable_scope('', reuse=True):
                        var = tf.get_variable(variable.name.split(':0')[0])
                        print(variable.name, np.mean(sess.run(var)))

        while True:

            # Read in batch data
            batch_x_np, batch_y_np, batch_c_np = sess.run(
                [batch_x, batch_y, batch_c])

            # Warp training images
            for img_num in range(batch_x_np.shape[0]):
                deg1 = (2 * np.random.rand() - 1) * 50
                deg2 = (2 * np.random.rand() - 1) * 50
                batch_x_np[img_num, ...] = cpm_utils.warpImage(
                    batch_x_np[img_num, ...], 0, deg1, deg2, 1, 30)
                batch_y_np[img_num, ...] = cpm_utils.warpImage(
                    batch_y_np[img_num, ...], 0, deg1, deg2, 1, 30)
                # Recompute the background channel as 1 - max over the joint channels
                batch_y_np[img_num, :, :, FLAGS.num_of_joints] = \
                    np.ones(shape=(FLAGS.input_size, FLAGS.input_size)) - \
                    np.max(batch_y_np[img_num, :, :, 0:FLAGS.num_of_joints], axis=2)
                batch_c_np[img_num, ...] = cpm_utils.warpImage(
                    batch_c_np[img_num, ...], 0, deg1, deg2, 1, 30).reshape(
                        (FLAGS.input_size, FLAGS.input_size, 1))

            # Convert image to grayscale
            if FLAGS.color_channel == 'GRAY':
                batch_x_gray_np = np.zeros(
                    (batch_x_np.shape[0], FLAGS.input_size, FLAGS.input_size,
                     1))
                for img_num in range(batch_x_np.shape[0]):
                    tmp = batch_x_np[img_num, ...]
                    tmp += 0.5
                    tmp *= 255
                    tmp = np.dot(tmp[..., :3], [0.114, 0.587, 0.299])
                    tmp /= 255
                    tmp -= 0.5
                    batch_x_gray_np[img_num, ...] = tmp.reshape(
                        (FLAGS.input_size, FLAGS.input_size, 1))
                batch_x_np = batch_x_gray_np

            # Recreate heatmaps
            gt_heatmap_np = cpm_utils.make_gaussian_batch(
                batch_y_np, FLAGS.heatmap_size, 3)

            # Update once
            stage_losses_np, total_loss_np, _, summary, current_lr, \
            stage_heatmap_np, global_step = sess.run([model.stage_loss,
                                                      model.total_loss,
                                                      model.train_op,
                                                      model.merged_summary,
                                                      model.lr,
                                                      model.stage_heatmap,
                                                      model.global_step
                                                      ],
                                                     feed_dict={input_placeholder: batch_x_np,
                                                                cmap_placeholder: batch_c_np,
                                                                hmap_placeholder: gt_heatmap_np})

            # Write logs
            tf_writer.add_summary(summary, global_step)

            # Draw intermediate results
            if global_step % 50 == 0:

                if FLAGS.color_channel == 'GRAY':
                    demo_img = np.repeat(batch_x_np[0], 3, axis=2)
                    demo_img += 0.5
                elif FLAGS.color_channel == 'RGB':
                    demo_img = batch_x_np[0] + 0.5
                demo_stage_heatmaps = []
                for stage in range(FLAGS.stages):
                    demo_stage_heatmap = stage_heatmap_np[stage][
                        0, :, :, 0:FLAGS.num_of_joints].reshape(
                            (FLAGS.heatmap_size, FLAGS.heatmap_size,
                             FLAGS.num_of_joints))
                    demo_stage_heatmap = cv2.resize(
                        demo_stage_heatmap,
                        (FLAGS.input_size, FLAGS.input_size))
                    demo_stage_heatmap = np.amax(demo_stage_heatmap, axis=2)
                    demo_stage_heatmap = np.reshape(
                        demo_stage_heatmap,
                        (FLAGS.input_size, FLAGS.input_size, 1))
                    demo_stage_heatmap = np.repeat(demo_stage_heatmap,
                                                   3,
                                                   axis=2)
                    demo_stage_heatmaps.append(demo_stage_heatmap)

                demo_gt_heatmap = gt_heatmap_np[0, :, :,
                                                0:FLAGS.num_of_joints].reshape(
                                                    (FLAGS.heatmap_size,
                                                     FLAGS.heatmap_size,
                                                     FLAGS.num_of_joints))
                demo_gt_heatmap = cv2.resize(
                    demo_gt_heatmap, (FLAGS.input_size, FLAGS.input_size))
                demo_gt_heatmap = np.amax(demo_gt_heatmap, axis=2)
                demo_gt_heatmap = np.reshape(
                    demo_gt_heatmap, (FLAGS.input_size, FLAGS.input_size, 1))
                demo_gt_heatmap = np.repeat(demo_gt_heatmap, 3, axis=2)

                if FLAGS.stages > 4:
                    upper_img = np.concatenate(
                        (demo_stage_heatmaps[0], demo_stage_heatmaps[1],
                         demo_stage_heatmaps[2]),
                        axis=1)
                    blend_img = 0.5 * demo_gt_heatmap + 0.5 * demo_img
                    lower_img = np.concatenate(
                        (demo_stage_heatmaps[FLAGS.stages - 1],
                         demo_gt_heatmap, blend_img),
                        axis=1)
                    demo_img = np.concatenate((upper_img, lower_img), axis=0)
                    cv2.imshow('current heatmap',
                               (demo_img * 255).astype(np.uint8))
                    cv2.waitKey()
                else:
                    upper_img = np.concatenate(
                        (demo_stage_heatmaps[FLAGS.stages - 1],
                         demo_gt_heatmap, demo_img),
                        axis=1)
                    cv2.imshow('current heatmap',
                               (upper_img * 255).astype(np.uint8))
                    cv2.waitKey()

            print('##========Iter {:>6d}========##'.format(global_step))
            print('Current learning rate: {:.8f}'.format(current_lr))
            for stage_num in range(FLAGS.stages):
                print('Stage {} loss: {:>.3f}'.format(
                    stage_num + 1, stage_losses_np[stage_num]))
            print('Total loss: {:>.3f}\n\n'.format(total_loss_np))

            # Save models
            if global_step % 5000 == 1:
                save_path_str = 'models/' + FLAGS.saved_model_name
                saver.save(sess=sess,
                           save_path=save_path_str,
                           global_step=global_step)
                print('\nModel checkpoint saved...\n')

            # Finish training
            if global_step == FLAGS.training_iterations:
                break

        coord.request_stop()
        coord.join(threads)

    print('Training done.')
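
For orientation, the ground-truth heatmaps produced by cpm_utils.make_gaussian_batch can be pictured with the sketch below. It is an illustrative stand-in that starts from joint coordinates rather than the warped label images the real function consumes, but it reproduces the background-channel convention (1 minus the max over joint channels) used in the warping loop above.

import numpy as np

def make_gaussian_heatmaps(joint_coords, heatmap_size, sigma=3.0):
    # joint_coords: (num_joints, 2) array of (row, col) in heatmap pixels
    num_joints = joint_coords.shape[0]
    x = np.arange(heatmap_size, dtype=np.float32)
    y = x[:, np.newaxis]
    maps = np.zeros((heatmap_size, heatmap_size, num_joints + 1), np.float32)
    for j, (r, c) in enumerate(joint_coords):
        maps[:, :, j] = np.exp(-((x - c) ** 2 + (y - r) ** 2) / (2.0 * sigma ** 2))
    # Background channel: 1 - max over the joint channels
    maps[:, :, num_joints] = 1.0 - np.max(maps[:, :, :num_joints], axis=2)
    return maps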
Example #4
def main(argv):
    tf_device = '/gpu:0'
    with tf.device(tf_device):
        
        """
        if FLAGS.color_channel == 'RGB':
            input_data = tf.placeholder(dtype=tf.float32,
                                        shape=[None, FLAGS.input_size, FLAGS.input_size, 3],
                                        name='input_image')
        else:
            input_data = tf.placeholder(dtype=tf.float32,
                                        shape=[None, FLAGS.input_size, FLAGS.input_size, 1],
                                        name='input_image')

        center_map = tf.placeholder(dtype=tf.float32,
                                    shape=[None, FLAGS.input_size, FLAGS.input_size, 1],
                                    name='center_map')
        """
        # model = cpm_body_slim.CPM_Model(FLAGS.stages, FLAGS.joints + 1)
        # model.build_model(input_data, center_map, 1)
        # no background channel
        model = cpm_body_slim.CPM_Model(FLAGS.input_size, FLAGS.hmap_size, FLAGS.stages,
                                        FLAGS.joints, 1, img_type=FLAGS.color_channel)

    saver = tf.train.Saver()

    """Create session and restore weights
    """
    sess = tf.Session()

    sess.run(tf.global_variables_initializer())
    if FLAGS.model_path.endswith('pkl'):
        model.load_weights_from_file(FLAGS.model_path, sess, False)
    else:
        saver.restore(sess, FLAGS.model_path)

    # The center map is dropped here; its purpose is unclear
    """
    # Create center_map: a Gaussian response placed at the center of the square input
    test_center_map = cpm_utils.gaussian_img(FLAGS.input_size,
                                             FLAGS.input_size,
                                             FLAGS.input_size / 2,
                                             FLAGS.input_size / 2,
                                             FLAGS.cmap_radius)
    test_center_map = np.reshape(test_center_map, [1, FLAGS.input_size,
                                                   FLAGS.input_size, 1])
    """

    # read in video / flow frames
    if FLAGS.DEMO_TYPE.endswith(('avi', 'flv', 'mp4')):
        # use imageio here: OpenCV's reader only handles '.avi' files
        cam = imageio.get_reader(FLAGS.DEMO_TYPE)

    # image processing
    with tf.device(tf_device):
        if FLAGS.DEMO_TYPE.endswith(('avi', 'flv', 'mp4')):
            ori_fps = cam.get_meta_data()['fps']

            cap = cv2.VideoCapture(FLAGS.DEMO_TYPE)
            total_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            (W,H) = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            pbar = tqdm(total=total_frame)

            print('This video fps is %f' % ori_fps)
            video_length = cam.get_length()
            # writer_path = os.path.join('results', os.path.basename(FLAGS.DEMO_TYPE))
            # !! MJPG frames in an .mp4 container are unreliable; 'mp4v' matches the output extension
            cv_writer = cv2.VideoWriter('results/result.mp4',
                                        cv2.VideoWriter_fourcc(*'mp4v'),
                                        ori_fps,
                                        (W, H))
            # imageio_writer = imageio.get_writer(writer_path, fps=ori_fps)

            try:
                for it, im in enumerate(cam):
                    test_img_t = time.time()

                    full_img = im.copy()    # draw the results on the original frame
                    # test_img is the original image resized to input_size, with padding
                    test_img = cpm_utils.read_image(im, [], FLAGS.input_size, 'VIDEO')
                    # redundant resize: read_image already returns input_size
                    # test_img_resize = cv2.resize(test_img, (FLAGS.input_size, FLAGS.input_size))
                    # print('img read time %f' % (time.time() - test_img_t))

                    if FLAGS.color_channel == 'GRAY':
                        test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2GRAY)

                    test_img_input = test_img / 256.0 - 0.5
                    test_img_input = np.expand_dims(test_img_input, axis=0)

                    # Inference
                    # fps_t = time.time()
                    """
                    predict_heatmap, stage_heatmap_np = sess.run([model.current_heatmap,
                                                                  model.stage_heatmap,
                                                                  ],
                                                                 feed_dict={'input_placeholder:0': test_img_input,
                                                                            'cmap_placeholder:0': test_center_map})
                    """
                    predict_heatmap, stage_heatmap_np = sess.run([model.current_heatmap,
                                                                  model.stage_heatmap,
                                                                  ],
                                                                 feed_dict={'input_placeholder:0': test_img_input})

                    # Show visualized image
                    demo_img = visualize_result(test_img, full_img, FLAGS, stage_heatmap_np)

                    #cv2.putText(demo_img, "FPS: %.1f" % (1 / (time.time() - fps_t)), (20, 20),  cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 1)
                    cv_writer.write(demo_img.astype(np.uint8))
                    pbar.update(1)
                    # print(1/(time.time()-test_img_t))
                    cv2.imwrite('results/pics_test/{}.png'.format(str(it).zfill(5)), demo_img[:,:,::-1])
                    # imageio_writer.append_data(demo_img[:, :, 1])
            except KeyboardInterrupt:
                print('Stopped! {}/{} frames captured!'.format(it, video_length))
            finally:
                cv_writer.release()
                # imageio_writer.close()

        elif FLAGS.DEMO_TYPE.endswith(('png', 'jpg')):
            test_img = cpm_utils.read_image(FLAGS.DEMO_TYPE, [], FLAGS.input_size, 'IMAGE')
            
            test_img_input = test_img / 256.0 - 0.5
            test_img_input = np.expand_dims(test_img_input, axis=0)

            # Inference
            fps_t = time.time()
            """
            stage_heatmap_np = sess.run([model.stage_heatmap[5]],
                                    feed_dict={'input_placeholder:0': test_img_input,
                                               'cmap_placeholder:0': test_center_map})
            """
            # last stage (the original hardcoded index 5 assumed a six-stage model)
            stage_heatmap_np = sess.run([model.stage_heatmap[FLAGS.stages - 1]],
                                    feed_dict={'input_placeholder:0': test_img_input})

            # Show visualized image
            ori_img = cv2.imread(FLAGS.DEMO_TYPE)
            demo_img = visualize_result(test_img, ori_img, FLAGS, stage_heatmap_np)
            cv2.imwrite('results/test.jpg', demo_img)
            print('fps: %.1f' % (1 / (time.time() - fps_t)))
        
        else:
            print('Demo type is not recognized, please check FLAGS.DEMO_TYPE!')
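
visualize_result in this example draws on the original frame (full_img / ori_img), which requires mapping heatmap peaks back to image coordinates. Below is a minimal sketch of that mapping under stated assumptions (heatmap_peaks_to_image is a hypothetical helper; the repo's function must additionally undo the padding applied by cpm_utils.read_image, which this ignores):

import numpy as np

def heatmap_peaks_to_image(stage_heatmap, hmap_size, num_joints, img_h, img_w):
    # Argmax each joint channel of the final-stage heatmap, then rescale
    # the peak from heatmap resolution to the original frame resolution
    coords = []
    for j in range(num_joints):
        r, c = np.unravel_index(np.argmax(stage_heatmap[0, :, :, j]),
                                (hmap_size, hmap_size))
        coords.append((int(r * img_h / hmap_size), int(c * img_w / hmap_size)))
    return coords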