def main(argv):
    """Train a CPM (Convolutional Pose Machine) model.

    Creates save/log directories derived from FLAGS, builds train/eval data
    generators, constructs the CPM graph with an RMSProp-optimized loss, and
    runs the training loop with periodic heatmap visualization, validation
    and checkpointing.

    :param argv: unused command-line arguments (app-runner convention)
    :return: None
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    """ Create dirs for saving models and logs
    """
    model_path_suffix = os.path.join(FLAGS.network_def,
                                     'input_{}_output_{}'.format(FLAGS.input_size, FLAGS.heatmap_size),
                                     'joints_{}'.format(FLAGS.num_of_joints),
                                     'stages_{}'.format(FLAGS.cpm_stages),
                                     'init_{}_rate_{}_step_{}'.format(FLAGS.init_lr, FLAGS.lr_decay_rate,
                                                                      FLAGS.lr_decay_step)
                                     )
    model_save_dir = os.path.join('models',
                                  'weights',
                                  model_path_suffix)
    train_log_save_dir = os.path.join('models',
                                      'logs',
                                      model_path_suffix,
                                      'train')
    test_log_save_dir = os.path.join('models',
                                     'logs',
                                     model_path_suffix,
                                     'test')
    # Portable directory creation: os.system('mkdir -p ...') fails on
    # Windows and on paths containing spaces, and spawns a shell needlessly.
    for _dir in (model_save_dir, train_log_save_dir, test_log_save_dir):
        if not os.path.exists(_dir):
            os.makedirs(_dir)

    """ Create data generator
    """
    g = Ensemble_data_generator.ensemble_data_generator(FLAGS.train_img_dir,
                                                        FLAGS.bg_img_dir,
                                                        FLAGS.batch_size, FLAGS.input_size, True, True,
                                                        FLAGS.augmentation_config, FLAGS.hnm, FLAGS.do_cropping)
    g_eval = Ensemble_data_generator.ensemble_data_generator(FLAGS.val_img_dir,
                                                             FLAGS.bg_img_dir,
                                                             FLAGS.batch_size, FLAGS.input_size, True, True,
                                                             FLAGS.augmentation_config, FLAGS.hnm, FLAGS.do_cropping)

    """ Build network graph
    """
    model = cpm_model.CPM_Model(input_size=FLAGS.input_size,
                                heatmap_size=FLAGS.heatmap_size,
                                stages=FLAGS.cpm_stages,
                                joints=FLAGS.num_of_joints,
                                img_type=FLAGS.color_channel,
                                is_training=True)
    model.build_loss(FLAGS.init_lr, FLAGS.lr_decay_rate, FLAGS.lr_decay_step, optimizer='RMSProp')
    print('=====Model Build=====\n')

    merged_summary = tf.summary.merge_all()

    """ Training
    """
    device_count = {'GPU': 1} if FLAGS.use_gpu else {'GPU': 0}
    with tf.Session(config=tf.ConfigProto(device_count=device_count,
                                          allow_soft_placement=True)) as sess:
        # Create tensorboard writers
        train_writer = tf.summary.FileWriter(train_log_save_dir, sess.graph)
        test_writer = tf.summary.FileWriter(test_log_save_dir, sess.graph)

        # Create model saver (max_to_keep=None keeps every checkpoint)
        saver = tf.train.Saver(max_to_keep=None)

        # Init all vars
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # Restore pretrained weights
        if FLAGS.pretrained_model != '':
            if FLAGS.pretrained_model.endswith('.pkl'):
                model.load_weights_from_file(FLAGS.pretrained_model, sess, finetune=True)

                # Check weights: print the mean of every trainable variable
                # so a bad restore is visible immediately.
                for variable in tf.trainable_variables():
                    with tf.variable_scope('', reuse=True):
                        var = tf.get_variable(variable.name.split(':0')[0])
                        print(variable.name, np.mean(sess.run(var)))

            else:
                saver.restore(sess, os.path.join(model_save_dir, FLAGS.pretrained_model))

                # check weights (same sanity print as above)
                for variable in tf.trainable_variables():
                    with tf.variable_scope('', reuse=True):
                        var = tf.get_variable(variable.name.split(':0')[0])
                        print(variable.name, np.mean(sess.run(var)))

        for training_itr in range(FLAGS.training_iters):
            t1 = time.time()

            # Read one batch of data
            batch_x_np, batch_joints_np = g.next()

            if FLAGS.normalize_img:
                # Normalize images to [-0.5, 0.5]
                batch_x_np = batch_x_np / 255.0 - 0.5
            else:
                batch_x_np -= 128.0

            # Generate ground-truth heatmaps from joint coordinates
            batch_gt_heatmap_np = cpm_utils.make_heatmaps_from_joints(FLAGS.input_size,
                                                                      FLAGS.heatmap_size,
                                                                      FLAGS.joint_gaussian_variance,
                                                                      batch_joints_np)

            # Forward and update weights
            stage_losses_np, total_loss_np, _, summaries, current_lr, \
            stage_heatmap_np, global_step = sess.run([model.stage_loss,
                                                      model.total_loss,
                                                      model.train_op,
                                                      merged_summary,
                                                      model.cur_lr,
                                                      model.stage_heatmap,
                                                      model.global_step
                                                      ],
                                                     feed_dict={model.input_images: batch_x_np,
                                                                model.gt_hmap_placeholder: batch_gt_heatmap_np})

            # Show training info
            print_current_training_stats(global_step, current_lr, stage_losses_np, total_loss_np, time.time() - t1)

            # Write logs
            train_writer.add_summary(summaries, global_step)

            # Draw intermediate results every 10 steps
            if (global_step + 1) % 10 == 0:
                if FLAGS.color_channel == 'GRAY':
                    demo_img = np.repeat(batch_x_np[0], 3, axis=2)
                    if FLAGS.normalize_img:
                        demo_img += 0.5
                    else:
                        demo_img += 128.0
                        demo_img /= 255.0
                elif FLAGS.color_channel == 'RGB':
                    if FLAGS.normalize_img:
                        demo_img = batch_x_np[0] + 0.5
                    else:
                        # FIX: demo_img was referenced before assignment here
                        # (UnboundLocalError); build it from the current
                        # batch first, mirroring the GRAY branch.
                        demo_img = batch_x_np[0] + 128.0
                        demo_img /= 255.0
                else:
                    raise ValueError('Non support image type.')

                # Collapse each stage's joint heatmaps to a single 3-channel
                # max-activation image at input resolution.
                demo_stage_heatmaps = []
                for stage in range(FLAGS.cpm_stages):
                    demo_stage_heatmap = stage_heatmap_np[stage][0, :, :, 0:FLAGS.num_of_joints].reshape(
                        (FLAGS.heatmap_size, FLAGS.heatmap_size, FLAGS.num_of_joints))
                    demo_stage_heatmap = cv2.resize(demo_stage_heatmap, (FLAGS.input_size, FLAGS.input_size))
                    demo_stage_heatmap = np.amax(demo_stage_heatmap, axis=2)
                    demo_stage_heatmap = np.reshape(demo_stage_heatmap, (FLAGS.input_size, FLAGS.input_size, 1))
                    demo_stage_heatmap = np.repeat(demo_stage_heatmap, 3, axis=2)
                    demo_stage_heatmaps.append(demo_stage_heatmap)

                # Same collapse for the ground-truth heatmap
                demo_gt_heatmap = batch_gt_heatmap_np[0, :, :, 0:FLAGS.num_of_joints].reshape(
                    (FLAGS.heatmap_size, FLAGS.heatmap_size, FLAGS.num_of_joints))
                demo_gt_heatmap = cv2.resize(demo_gt_heatmap, (FLAGS.input_size, FLAGS.input_size))
                demo_gt_heatmap = np.amax(demo_gt_heatmap, axis=2)
                demo_gt_heatmap = np.reshape(demo_gt_heatmap, (FLAGS.input_size, FLAGS.input_size, 1))
                demo_gt_heatmap = np.repeat(demo_gt_heatmap, 3, axis=2)

                if FLAGS.cpm_stages > 4:
                    # 2x3 montage: first three stages on top; last stage,
                    # ground truth and an image/GT blend on the bottom.
                    upper_img = np.concatenate((demo_stage_heatmaps[0], demo_stage_heatmaps[1], demo_stage_heatmaps[2]),
                                               axis=1)
                    if FLAGS.normalize_img:
                        blend_img = 0.5 * demo_img + 0.5 * demo_gt_heatmap
                    else:
                        blend_img = 0.5 * demo_img / 255.0 + 0.5 * demo_gt_heatmap
                    lower_img = np.concatenate((demo_stage_heatmaps[FLAGS.cpm_stages - 1], demo_gt_heatmap, blend_img),
                                               axis=1)
                    demo_img = np.concatenate((upper_img, lower_img), axis=0)
                    cv2.imshow('current heatmap', (demo_img * 255).astype(np.uint8))
                    cv2.waitKey(1000)
                else:
                    upper_img = np.concatenate((demo_stage_heatmaps[FLAGS.cpm_stages - 1], demo_gt_heatmap, demo_img),
                                               axis=1)
                    cv2.imshow('current heatmap', (upper_img * 255).astype(np.uint8))
                    cv2.waitKey(1000)

            if (global_step + 1) % FLAGS.validation_iters == 0:
                mean_val_loss = 0
                cnt = 0

                # Average the loss over 10 validation batches.
                while cnt < 10:
                    batch_x_np, batch_joints_np = g_eval.next()
                    # NOTE(review): validation always normalizes to
                    # [-0.5, 0.5] even when FLAGS.normalize_img is False —
                    # confirm this matches the training preprocessing.
                    batch_x_np = batch_x_np / 255.0 - 0.5

                    batch_gt_heatmap_np = cpm_utils.make_heatmaps_from_joints(FLAGS.input_size,
                                                                              FLAGS.heatmap_size,
                                                                              FLAGS.joint_gaussian_variance,
                                                                              batch_joints_np)
                    total_loss_np, summaries = sess.run([model.total_loss, merged_summary],
                                                        feed_dict={model.input_images: batch_x_np,
                                                                   model.gt_hmap_placeholder: batch_gt_heatmap_np})
                    mean_val_loss += total_loss_np
                    cnt += 1

                print('\nValidation loss: {:>7.2f}\n'.format(mean_val_loss / cnt))
                test_writer.add_summary(summaries, global_step)

            # Save models
            if (global_step + 1) % FLAGS.model_save_iters == 0:
                saver.save(sess=sess, save_path=model_save_dir + '/' + FLAGS.network_def.split('.py')[0],
                           global_step=(global_step + 1))
                print('\nModel checkpoint saved...\n')

            # Finish training once the optimizer's global step reaches the
            # requested number of iterations.
            if global_step == FLAGS.training_iters:
                break
    print('Training done.')
# ===== Example 2 =====
def main(argv):
    """Train a CPM model from TFRecord datasets.

    Same training pipeline as the generator-based variant, but batches come
    from 'train.tfrecords' / 'test.tfrecords' as tensors that are evaluated
    with sess.run(), and intermediate visualizations are written to disk
    instead of shown in a window.

    :param argv: unused command-line arguments (app-runner convention)
    :return: None
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    """ Create dirs for saving models and logs
    """
    model_path_suffix = os.path.join(
        FLAGS.network_def, 'input_{}_output_{}'.format(FLAGS.input_size,
                                                       FLAGS.heatmap_size),
        'joints_{}'.format(FLAGS.num_of_joints),
        'stages_{}'.format(FLAGS.cpm_stages),
        'init_{}_rate_{}_step_{}'.format(FLAGS.init_lr, FLAGS.lr_decay_rate,
                                         FLAGS.lr_decay_step))
    model_save_dir = os.path.join('models', 'weights', model_path_suffix)
    train_log_save_dir = os.path.join('models', 'logs', model_path_suffix,
                                      'train')
    test_log_save_dir = os.path.join('models', 'logs', model_path_suffix,
                                     'test')
    # Portable directory creation instead of shelling out to 'mkdir -p'.
    for _dir in (model_save_dir, train_log_save_dir, test_log_save_dir):
        if not os.path.exists(_dir):
            os.makedirs(_dir)
    """ Create data generator
    """
    # NOTE(review): batch size / input size / heatmap size are hard-coded
    # here (5, 256, 64) instead of coming from FLAGS — confirm intentional.
    g = Ensemble_data_generator.ensemble_data_generator(
        "train.tfrecords", 5, 256, 64)
    g_eval = Ensemble_data_generator.ensemble_data_generator(
        "test.tfrecords", 5, 256, 64)
    """ Build network graph
    """
    model = cpm_model.CPM_Model(input_size=FLAGS.input_size,
                                heatmap_size=FLAGS.heatmap_size,
                                stages=FLAGS.cpm_stages,
                                joints=FLAGS.num_of_joints,
                                img_type=FLAGS.color_channel,
                                is_training=True)
    model.build_loss(FLAGS.init_lr,
                     FLAGS.lr_decay_rate,
                     FLAGS.lr_decay_step,
                     optimizer='RMSProp')
    print('=====Model Build=====\n')

    merged_summary = tf.summary.merge_all()
    """ Training
    """
    device_count = {'GPU': 1} if FLAGS.use_gpu else {'GPU': 0}
    with tf.Session(config=tf.ConfigProto(device_count=device_count,
                                          allow_soft_placement=True)) as sess:
        # Create tensorboard writers
        train_writer = tf.summary.FileWriter(train_log_save_dir, sess.graph)
        test_writer = tf.summary.FileWriter(test_log_save_dir, sess.graph)

        # Create model saver (max_to_keep=None keeps every checkpoint)
        saver = tf.train.Saver(max_to_keep=None)

        # Init all vars
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # Restore pretrained weights
        if FLAGS.pretrained_model != '':
            if FLAGS.pretrained_model.endswith('.pkl'):
                model.load_weights_from_file(FLAGS.pretrained_model,
                                             sess,
                                             finetune=True)

                # Check weights: print the mean of every trainable variable
                # so a bad restore is visible immediately.
                for variable in tf.trainable_variables():
                    with tf.variable_scope('', reuse=True):
                        var = tf.get_variable(variable.name.split(':0')[0])
                        print(variable.name, np.mean(sess.run(var)))

            else:
                saver.restore(
                    sess, os.path.join(model_save_dir, FLAGS.pretrained_model))

                # check weights (same sanity print as above)
                for variable in tf.trainable_variables():
                    with tf.variable_scope('', reuse=True):
                        var = tf.get_variable(variable.name.split(':0')[0])
                        print(variable.name, np.mean(sess.run(var)))

        # Fetch the batch tensors once; each sess.run(next_ele) below pulls
        # a fresh batch from the TFRecord pipeline.
        next_ele = g.next()
        for training_itr in range(FLAGS.training_iters):
            t1 = time.time()

            # Read one batch of data
            batch_x_np, batch_joints_np = sess.run(next_ele)

            if FLAGS.normalize_img:
                # Normalize images to [-0.5, 0.5]
                batch_x_np = batch_x_np / 255.0 - 0.5
            else:
                batch_x_np -= 128.0

            # Generate ground-truth heatmaps from joint coordinates
            batch_gt_heatmap_np = cpm_utils.make_heatmaps_from_joints(
                FLAGS.input_size, FLAGS.heatmap_size,
                FLAGS.joint_gaussian_variance, batch_joints_np)

            # Forward and update weights
            stage_losses_np, total_loss_np, _, summaries, current_lr, \
            stage_heatmap_np, global_step = sess.run([model.stage_loss,
                                                      model.total_loss,
                                                      model.train_op,
                                                      merged_summary,
                                                      model.cur_lr,
                                                      model.stage_heatmap,
                                                      model.global_step
                                                      ],
                                                     feed_dict={model.input_images: batch_x_np,
                                                                model.gt_hmap_placeholder: batch_gt_heatmap_np})

            # Show training info
            print_current_training_stats(global_step, current_lr,
                                         stage_losses_np, total_loss_np,
                                         time.time() - t1)

            # Write logs
            train_writer.add_summary(summaries, global_step)

            # TODO: each validation data set, do prediction

            # Draw intermediate results every 100 steps
            if (global_step + 1) % 100 == 0:
                if FLAGS.color_channel == 'GRAY':
                    demo_img = np.repeat(batch_x_np[0], 3, axis=2)
                    if FLAGS.normalize_img:
                        demo_img += 0.5
                    else:
                        demo_img += 128.0
                        demo_img /= 255.0
                elif FLAGS.color_channel == 'RGB':
                    if FLAGS.normalize_img:
                        demo_img = batch_x_np[0] + 0.5
                    else:
                        # FIX: demo_img was referenced before assignment here
                        # (UnboundLocalError); build it from the current
                        # batch first, mirroring the GRAY branch.
                        demo_img = batch_x_np[0] + 128.0
                        demo_img /= 255.0
                else:
                    raise ValueError('Non support image type.')

                # Collapse each stage's joint heatmaps to a single 3-channel
                # max-activation image at input resolution.
                demo_stage_heatmaps = []
                for stage in range(FLAGS.cpm_stages):
                    demo_stage_heatmap = stage_heatmap_np[stage][
                        0, :, :, 0:FLAGS.num_of_joints].reshape(
                            (FLAGS.heatmap_size, FLAGS.heatmap_size,
                             FLAGS.num_of_joints))
                    demo_stage_heatmap = cv2.resize(
                        demo_stage_heatmap,
                        (FLAGS.input_size, FLAGS.input_size))
                    demo_stage_heatmap = np.amax(demo_stage_heatmap, axis=2)
                    demo_stage_heatmap = np.reshape(
                        demo_stage_heatmap,
                        (FLAGS.input_size, FLAGS.input_size, 1))
                    demo_stage_heatmap = np.repeat(demo_stage_heatmap,
                                                   3,
                                                   axis=2)
                    demo_stage_heatmaps.append(demo_stage_heatmap)

                # Same collapse for the ground-truth heatmap
                demo_gt_heatmap = batch_gt_heatmap_np[
                    0, :, :, 0:FLAGS.num_of_joints].reshape(
                        (FLAGS.heatmap_size, FLAGS.heatmap_size,
                         FLAGS.num_of_joints))
                demo_gt_heatmap = cv2.resize(
                    demo_gt_heatmap, (FLAGS.input_size, FLAGS.input_size))
                demo_gt_heatmap = np.amax(demo_gt_heatmap, axis=2)
                demo_gt_heatmap = np.reshape(
                    demo_gt_heatmap, (FLAGS.input_size, FLAGS.input_size, 1))
                demo_gt_heatmap = np.repeat(demo_gt_heatmap, 3, axis=2)

                if FLAGS.cpm_stages > 4:
                    # 2x3 montage: first three stages on top; last stage,
                    # ground truth and an image/GT blend on the bottom.
                    upper_img = np.concatenate(
                        (demo_stage_heatmaps[0], demo_stage_heatmaps[1],
                         demo_stage_heatmaps[2]),
                        axis=1)
                    if FLAGS.normalize_img:
                        blend_img = 0.5 * demo_img + 0.5 * demo_gt_heatmap
                    else:
                        blend_img = 0.5 * demo_img / 255.0 + 0.5 * demo_gt_heatmap
                    lower_img = np.concatenate(
                        (demo_stage_heatmaps[FLAGS.cpm_stages - 1],
                         demo_gt_heatmap, blend_img),
                        axis=1)
                    demo_img = np.concatenate((upper_img, lower_img), axis=0)
                    # NOTE(review): hard-coded absolute output path — make
                    # this a FLAG before running on another machine.
                    cv2.imwrite(
                        "/home/qiaohe/convolutional-pose-machines-tensorflow/validation_img/"
                        + str(global_step) + ".jpg", demo_img * 255)

                else:
                    upper_img = np.concatenate(
                        (demo_stage_heatmaps[FLAGS.cpm_stages - 1],
                         demo_gt_heatmap, demo_img),
                        axis=1)
                    # NOTE(review): hard-coded absolute output path — make
                    # this a FLAG before running on another machine.
                    cv2.imwrite(
                        "/home/qiaohe/convolutional-pose-machines-tensorflow/validation_img/"
                        + str(global_step) + ".jpg", upper_img * 255)

            if (global_step + 1) % FLAGS.validation_iters == 0:
                mean_val_loss = 0
                cnt = 0

                # Average the loss over 10 validation batches; g_eval.next()
                # yields tensors that must be evaluated with sess.run().
                while cnt < 10:
                    batch_x_np, batch_joints_np = g_eval.next()

                    batch_x_np, batch_joints_np = sess.run(
                        [batch_x_np, batch_joints_np])
                    # NOTE(review): validation always normalizes to
                    # [-0.5, 0.5] even when FLAGS.normalize_img is False —
                    # confirm this matches the training preprocessing.
                    batch_x_np = batch_x_np / 255.0 - 0.5

                    batch_gt_heatmap_np = cpm_utils.make_heatmaps_from_joints(
                        FLAGS.input_size, FLAGS.heatmap_size,
                        FLAGS.joint_gaussian_variance, batch_joints_np)
                    total_loss_np, summaries = sess.run(
                        [model.total_loss, merged_summary],
                        feed_dict={
                            model.input_images: batch_x_np,
                            model.gt_hmap_placeholder: batch_gt_heatmap_np
                        })
                    mean_val_loss += total_loss_np
                    cnt += 1

                print('\nValidation loss: {:>7.2f}\n'.format(mean_val_loss /
                                                             cnt))
                test_writer.add_summary(summaries, global_step)

            # Save models
            if (global_step + 1) % FLAGS.model_save_iters == 0:
                saver.save(sess=sess,
                           save_path=model_save_dir + '/' +
                           FLAGS.network_def.split('.py')[0],
                           global_step=(global_step + 1))
                print('\nModel checkpoint saved...\n')

            # Finish training once the optimizer's global step reaches the
            # requested number of iterations.
            if global_step == FLAGS.training_iters:
                break
    print('Training done.')
# ===== Example 3 =====
def main():
    """Hard-negative mining over the annotated training set.

    Restores a trained CPM model, evaluates the heatmap loss on every
    annotated hand image, and writes the annotations whose loss exceeds a
    threshold (150.0) to a per-person 'attr_data_hnm.json' for retraining.

    :return: None
    """
    model = cpm_model.CPM_Model(input_size=FLAGS.input_size,
                                heatmap_size=FLAGS.heatmap_size,
                                stages=FLAGS.cpm_stages,
                                joints=FLAGS.num_of_joints,
                                img_type=FLAGS.color_channel,
                                is_training=False)
    model.build_loss(FLAGS.init_lr,
                     FLAGS.lr_decay_rate,
                     FLAGS.lr_decay_step,
                     optimizer='RMSProp')
    saver = tf.train.Saver()

    # NOTE(review): this generator is constructed but never consumed below
    # (images are loaded directly with cv2.imread) — confirm it is needed
    # for a side effect before removing.
    g = Ensemble_data_generator.ensemble_data_generator(
        FLAGS.train_img_dir, None, FLAGS.batch_size, FLAGS.input_size, True,
        False, FLAGS.augmentation_config, False)

    device_count = {'GPU': 1} if FLAGS.use_gpu else {'GPU': 0}
    sess_config = tf.ConfigProto(device_count=device_count)
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True
    with tf.Session(config=sess_config) as sess:

        # Locate the checkpoint directory matching the current FLAGS.
        model_path_suffix = os.path.join(
            FLAGS.network_def,
            'input_{}_output_{}'.format(FLAGS.input_size, FLAGS.heatmap_size),
            'joints_{}'.format(FLAGS.num_of_joints),
            'stages_{}'.format(FLAGS.cpm_stages),
            'init_{}_rate_{}_step_{}'.format(FLAGS.init_lr,
                                             FLAGS.lr_decay_rate,
                                             FLAGS.lr_decay_step))
        model_save_dir = os.path.join('models', 'weights', model_path_suffix)
        print('Load model from [{}]'.format(
            os.path.join(model_save_dir, FLAGS.model_path)))
        if FLAGS.model_path.endswith('pkl'):
            model.load_weights_from_file(FLAGS.model_path, sess, False)
        else:
            saver.restore(sess, os.path.join(model_save_dir, FLAGS.model_path))
        print('Load model done')

        bbox_offset = 100  # padding (pixels) added around the annotated bbox
        for person_dir in os.listdir(FLAGS.train_img_dir):
            json_file_path = os.path.join(FLAGS.train_img_dir, person_dir,
                                          'attr_data.json')
            # One list of hard-negative annotations per camera (11 cameras).
            hnm_json_list = [[] for _ in range(11)]

            with open(json_file_path, 'r') as f:
                json_file = json.load(f)

            img_cnt = 0
            hnm_cnt = 0
            for cam_id in range(11):
                for img_id in range(len(json_file[cam_id])):
                    img_path = os.path.join(FLAGS.train_img_dir, person_dir,
                                            'undistorted_img',
                                            json_file[cam_id][img_id]['name'])
                    img = cv2.imread(img_path)

                    # Read joints and crop a padded bbox around the hand.
                    hand_2d_joints = np.zeros(shape=(21, 2))
                    bbox = json_file[cam_id][img_id]['bbox']
                    bbox[0] = max(bbox[0] - bbox_offset, 0)
                    bbox[1] = max(bbox[1] - bbox_offset, 0)
                    # FIX: the slice below uses bbox[2] as the column (x)
                    # bound and bbox[3] as the row (y) bound, so bbox[2]
                    # must be clamped to the image width (shape[1]) and
                    # bbox[3] to the height (shape[0]); the clamps were
                    # swapped, mis-cropping images near the border.
                    bbox[2] = min(bbox[2] + bbox_offset, img.shape[1])
                    bbox[3] = min(bbox[3] + bbox_offset, img.shape[0])
                    img = img[bbox[1]:bbox[3], bbox[0]:bbox[2]]

                    # 21 joints: 5 fingers x 4 joints, plus the wrist.
                    for i, finger_name in enumerate(
                        ['thumb', 'index', 'middle', 'ring', 'pinky']):
                        for j, joint_name in enumerate(
                            ['tip', 'dip', 'pip', 'mcp']):
                            hand_2d_joints[i * 4 + j, :] = \
                            json_file[cam_id][img_id][finger_name][joint_name]['pose2']
                    hand_2d_joints[
                        20, :] = json_file[cam_id][img_id]['wrist']['pose2']
                    # Shift joints into the cropped coordinate frame.
                    hand_2d_joints[:, 0] -= bbox[0]
                    hand_2d_joints[:, 1] -= bbox[1]

                    # Normalize to [-0.5, 0.5] as during training.
                    img = img / 255.0 - 0.5

                    img, hand_2d_joints = scale_square_data(
                        img, hand_2d_joints, FLAGS.input_size)

                    # Add the batch dimension expected by the model.
                    img = np.expand_dims(img, axis=0)
                    hand_2d_joints = np.expand_dims(hand_2d_joints, axis=0)

                    gt_heatmap_np = cpm_utils.make_heatmaps_from_joints(
                        FLAGS.input_size, FLAGS.heatmap_size,
                        FLAGS.joint_gaussian_variance, hand_2d_joints)

                    loss, = sess.run(
                        [model.total_loss],
                        feed_dict={
                            model.input_images: img,
                            model.gt_hmap_placeholder: gt_heatmap_np
                        })

                    img_cnt += 1

                    # High-loss samples are the hard negatives we keep.
                    if loss > 150.0:
                        hnm_json_list[cam_id].append(json_file[cam_id][img_id])
                        hnm_cnt += 1
                        print('hnm cnt {} / {}'.format(hnm_cnt, img_cnt))

            # FIX: open in text mode — json.dump writes str, so 'wb' raises
            # TypeError on Python 3 ('w' also works on Python 2).
            with open(
                    os.path.join(FLAGS.train_img_dir, person_dir,
                                 'attr_data_hnm.json'), 'w') as f:
                json.dump(hnm_json_list, f)
                print('write done with {}'.format(person_dir))
# ===== Example 4 =====
def main(argv):
    """Train a CPM (Convolutional Pose Machine) heatmap-regression model.

    Sets up checkpoint/log directories, builds the data generators and the
    model graph, then runs the training loop with periodic validation and
    model checkpointing.

    :param argv: command-line arguments (unused; required by the app runner)
    :return: None
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    """ Create dirs for saving models and logs
    """
    model_path_suffix = os.path.join(FLAGS.network_def,
                                     'input_{}_output_{}'.format(FLAGS.input_size, FLAGS.heatmap_size),
                                     'joints_{}'.format(FLAGS.num_of_joints),
                                     'stages_{}'.format(FLAGS.cpm_stages),
                                     'init_{}_rate_{}_step_{}'.format(FLAGS.init_lr, FLAGS.lr_decay_rate,
                                                                      FLAGS.lr_decay_step)
                                     )
    model_save_dir = os.path.join('models',
                                  'weights',
                                  model_path_suffix)
    train_log_save_dir = os.path.join('models',
                                      'logs',
                                      model_path_suffix,
                                      'train')
    test_log_save_dir = os.path.join('models',
                                     'logs',
                                     model_path_suffix,
                                     'test')
    # os.makedirs is portable and safer than shelling out via os.system('mkdir -p').
    os.makedirs(model_save_dir, exist_ok=True)
    os.makedirs(train_log_save_dir, exist_ok=True)
    os.makedirs(test_log_save_dir, exist_ok=True)

    """ Create data generator
    """
    # Load annotation data; the context manager guarantees the file handle is
    # closed (the original left it open).
    with open(FLAGS.anns_path, 'r') as anns_file:
        anns = json.load(anns_file)

    # The first 1000 annotations form the validation split; the rest train.
    anns_valid = anns[:1000]
    anns_train = anns[1000:]

    g = process_data.train_generator(FLAGS.input_size, FLAGS.batch_size, FLAGS.train_img_dir, anns_train)
    g_valid = process_data.train_generator(FLAGS.input_size, FLAGS.batch_size, FLAGS.train_img_dir, anns_valid)

    # Center maps were dropped from this pipeline:
    # center_map_batch = cpm_utils.make_center_maps(FLAGS.input_size, FLAGS.batch_size, FLAGS.center_radius)

    """ Build network graph
    """
    # NOTE(review): 'joints + 1' presumably adds a background heatmap channel
    # -- confirm against cpm_model.CPM_Model.
    model = cpm_model.CPM_Model(input_size=FLAGS.input_size,
                                heatmap_size=FLAGS.heatmap_size,
                                stages=FLAGS.cpm_stages,
                                joints=FLAGS.num_of_joints + 1,
                                batch_size=FLAGS.batch_size,
                                img_type=FLAGS.color_channel,
                                is_training=True)

    model.build_loss(FLAGS.init_lr, FLAGS.lr_decay_rate, FLAGS.lr_decay_step, optimizer='Adam')
    print('=====Model Build=====\n')

    merged_summary = tf.summary.merge_all()

    """ Training
    """
    device_count = {'GPU': 1} if FLAGS.use_gpu else {'GPU': 0}
    with tf.Session(config=tf.ConfigProto(device_count=device_count,
                                          allow_soft_placement=True,
                                          gpu_options=tf.GPUOptions(allow_growth=True))) as sess:
        # Create tensorboard writers
        train_writer = tf.summary.FileWriter(train_log_save_dir, sess.graph)
        test_writer = tf.summary.FileWriter(test_log_save_dir, sess.graph)

        # Create model saver (max_to_keep=None keeps every checkpoint)
        saver = tf.train.Saver(max_to_keep=None)

        # Init all vars
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        # Restore pretrained weights
        if FLAGS.pretrained_model != '':
            if not os.path.exists(FLAGS.pretrained_model):
                # BUG FIX: corrected the 'doses' typo in the error message.
                raise IOError('Model does not exist!')
            elif FLAGS.pretrained_model.endswith('.pkl'):
                model.load_weights_from_file(FLAGS.pretrained_model, sess, finetune=True)
                # Sanity-check restored weights by printing each variable's mean.
                for variable in tf.trainable_variables():
                    with tf.variable_scope('', reuse=True):
                        var = tf.get_variable(variable.name.split(':0')[0])
                        print(variable.name, np.mean(sess.run(var)))
            else:
                saver.restore(sess, os.path.join(model_save_dir, FLAGS.pretrained_model))

        for training_itr in range(FLAGS.training_iters):
            t1 = time.time()

            # Read one batch of data (next() is the idiomatic form of __next__()).
            batch_x_np, batch_joints_np = next(g)

            if FLAGS.normalize_img:
                # Normalize images to roughly [-0.5, 0.5)
                batch_x_np = batch_x_np / 256.0 - 0.5
            else:
                # Zero-center only
                batch_x_np -= 128.0

            # Generate ground-truth heatmaps from joint coordinates
            batch_gt_heatmap_np = cpm_utils.make_heatmaps_from_joints(FLAGS.input_size,
                                                                      FLAGS.heatmap_size,
                                                                      FLAGS.joint_gaussian_variance,
                                                                      batch_joints_np)

            # Forward pass and weight update
            stage_losses_np, total_loss_np, _, summaries, current_lr, \
            stage_heatmap_np, global_step = sess.run([model.stage_loss,
                                                      model.total_loss,
                                                      model.train_op,
                                                      merged_summary,
                                                      model.cur_lr,
                                                      model.stage_heatmap,
                                                      model.global_step
                                                      ],
                                                     feed_dict={model.input_images: batch_x_np,
                                                                model.gt_hmap_placeholder: batch_gt_heatmap_np})

            # Show training info
            print_current_training_stats(global_step, current_lr, stage_losses_np, total_loss_np, time.time() - t1)

            # Write logs
            train_writer.add_summary(summaries, global_step)

            # Periodic validation
            if (global_step + 1) % FLAGS.validation_iters == 0:
                mean_val_loss = 0
                cnt = 0

                # Average the loss over 30 validation batches
                while cnt < 30:
                    batch_x_np, batch_joints_np = next(g_valid)
                    # CONSISTENCY FIX: preprocess validation images exactly like
                    # training images (the original always normalized here, even
                    # when FLAGS.normalize_img was False).
                    if FLAGS.normalize_img:
                        batch_x_np = batch_x_np / 256.0 - 0.5
                    else:
                        batch_x_np -= 128.0

                    batch_gt_heatmap_np = cpm_utils.make_heatmaps_from_joints(FLAGS.input_size,
                                                                              FLAGS.heatmap_size,
                                                                              FLAGS.joint_gaussian_variance,
                                                                              batch_joints_np)
                    total_loss_np, summaries = sess.run([model.total_loss, merged_summary],
                                                        feed_dict={model.input_images: batch_x_np,
                                                                   model.gt_hmap_placeholder: batch_gt_heatmap_np})

                    mean_val_loss += total_loss_np
                    cnt += 1

                print('\nValidation loss: {:>7.2f}\n'.format(mean_val_loss / cnt))
                test_writer.add_summary(summaries, global_step)

            # Save models
            if (global_step + 1) % FLAGS.model_save_iters == 0:
                saver.save(sess=sess, save_path=model_save_dir + '/' + FLAGS.network_def,
                           global_step=(global_step + 1))
                print('\nModel checkpoint saved...\n')

            # Finish training
            if global_step == FLAGS.training_iters:
                break
    print('Training done.')
def main():
    """Run hard-negative mining (HNM) over the training set.

    Loads a trained CPM model, computes the per-image loss for every annotated
    sample, and records the samples whose loss exceeds a fixed threshold into a
    per-person ``attr_data_hnm.json`` file for later retraining.

    NOTE(review): this redefines the module-level ``main`` declared earlier in
    the file; only this definition is in effect after the module loads.

    :return: None
    """
    model = cpm_model.CPM_Model(input_size=FLAGS.input_size,
                                heatmap_size=FLAGS.heatmap_size,
                                stages=FLAGS.cpm_stages,
                                joints=FLAGS.num_of_joints,
                                img_type=FLAGS.color_channel,
                                is_training=False)
    model.build_loss(FLAGS.init_lr, FLAGS.lr_decay_rate, FLAGS.lr_decay_step, optimizer='RMSProp')
    saver = tf.train.Saver()

    g = Ensemble_data_generator.ensemble_data_generator(FLAGS.train_img_dir,
                                                        None,
                                                        FLAGS.batch_size, FLAGS.input_size, True, False,
                                                        FLAGS.augmentation_config, False)

    device_count = {'GPU': 1} if FLAGS.use_gpu else {'GPU': 0}
    sess_config = tf.ConfigProto(device_count=device_count)
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
    sess_config.gpu_options.allow_growth = True
    sess_config.allow_soft_placement = True
    with tf.Session(config=sess_config) as sess:

        model_path_suffix = os.path.join(FLAGS.network_def,
                                         'input_{}_output_{}'.format(FLAGS.input_size, FLAGS.heatmap_size),
                                         'joints_{}'.format(FLAGS.num_of_joints),
                                         'stages_{}'.format(FLAGS.cpm_stages),
                                         'init_{}_rate_{}_step_{}'.format(FLAGS.init_lr, FLAGS.lr_decay_rate,
                                                                          FLAGS.lr_decay_step)
                                         )
        model_save_dir = os.path.join('models',
                                      'weights',
                                      model_path_suffix)
        print('Load model from [{}]'.format(os.path.join(model_save_dir, FLAGS.model_path)))
        if FLAGS.model_path.endswith('pkl'):
            model.load_weights_from_file(FLAGS.model_path, sess, False)
        else:
            saver.restore(sess, os.path.join(model_save_dir, FLAGS.model_path))
        print('Load model done')

        # Margin (pixels) by which each annotated bbox is enlarged before cropping.
        bbox_offset = 100
        for person_dir in os.listdir(FLAGS.train_img_dir):
            json_file_path = os.path.join(FLAGS.train_img_dir, person_dir, 'attr_data.json')
            # One list of hard samples per camera (11 cameras in this dataset).
            hnm_json_list = [[] for _ in range(11)]

            with open(json_file_path, 'r') as f:
                json_file = json.load(f)

            img_cnt = 0
            hnm_cnt = 0
            for cam_id in range(11):
                for img_id in range(len(json_file[cam_id])):
                    img_path = os.path.join(FLAGS.train_img_dir,
                                            person_dir,
                                            'undistorted_img',
                                            json_file[cam_id][img_id]['name'])
                    img = cv2.imread(img_path)

                    # Pad and clamp the hand bounding box, then crop. The crop
                    # below slices rows with bbox[1]:bbox[3] and columns with
                    # bbox[0]:bbox[2], so bbox is [x_min, y_min, x_max, y_max].
                    hand_2d_joints = np.zeros(shape=(21, 2))
                    bbox = json_file[cam_id][img_id]['bbox']
                    bbox[0] = max(bbox[0] - bbox_offset, 0)
                    bbox[1] = max(bbox[1] - bbox_offset, 0)
                    # BUG FIX: x_max must clamp to image width (shape[1]) and
                    # y_max to image height (shape[0]); the original code had
                    # the two clamps swapped.
                    bbox[2] = min(bbox[2] + bbox_offset, img.shape[1])
                    bbox[3] = min(bbox[3] + bbox_offset, img.shape[0])
                    img = img[bbox[1]:bbox[3],
                          bbox[0]:bbox[2]]

                    # 21 joints: 4 per finger (tip/dip/pip/mcp) plus the wrist.
                    for i, finger_name in enumerate(['thumb', 'index', 'middle', 'ring', 'pinky']):
                        for j, joint_name in enumerate(['tip', 'dip', 'pip', 'mcp']):
                            hand_2d_joints[i * 4 + j, :] = \
                            json_file[cam_id][img_id][finger_name][joint_name]['pose2']
                    hand_2d_joints[20, :] = json_file[cam_id][img_id]['wrist']['pose2']
                    # Shift joints into the cropped image's coordinate frame.
                    hand_2d_joints[:, 0] -= bbox[0]
                    hand_2d_joints[:, 1] -= bbox[1]

                    # Normalize pixels to roughly [-0.5, 0.5).
                    img = img / 255.0 - 0.5

                    img, hand_2d_joints = scale_square_data(img, hand_2d_joints, FLAGS.input_size)

                    # Add the batch dimension expected by the model.
                    img = np.expand_dims(img, axis=0)
                    hand_2d_joints = np.expand_dims(hand_2d_joints, axis=0)

                    gt_heatmap_np = cpm_utils.make_heatmaps_from_joints(FLAGS.input_size,
                                                                        FLAGS.heatmap_size,
                                                                        FLAGS.joint_gaussian_variance,
                                                                        hand_2d_joints)

                    loss, = sess.run([model.total_loss],
                                     feed_dict={model.input_images: img,
                                                model.gt_hmap_placeholder: gt_heatmap_np})

                    img_cnt += 1

                    # Samples whose loss exceeds the threshold are "hard" and
                    # are kept for hard-negative-mining retraining.
                    if loss > 150.0:
                        hnm_json_list[cam_id].append(json_file[cam_id][img_id])
                        hnm_cnt += 1
                        print('hnm cnt {} / {}'.format(hnm_cnt, img_cnt))

            # BUG FIX: json.dump writes str in Python 3 (this file uses
            # Python 3 idioms throughout), so the output file must be opened
            # in text mode ('w'), not binary ('wb').
            with open(os.path.join(FLAGS.train_img_dir, person_dir, 'attr_data_hnm.json'), 'w') as f:
                json.dump(hnm_json_list, f)
                print('write done with {}'.format(person_dir))