def main(option):
    """Compare the user's pose against a reference pose, frame by frame.

    For each frame pair the frames are downscaled to half size and passed
    through ``run_predict``. The results are converted to limb slopes with
    ``slope_calc`` and compared with ``compare_images`` (threshold 0.75).
    The comparison result, or ``None`` when it could not be computed, is
    passed to ``vibrate``. Both frames are then upscaled back and shown in
    OpenCV windows.

    The loop ends when ``q`` is pressed or when either capture fails to
    return a frame. Captures and windows are released even if prediction
    raises.

    Args:
        option: Not used by this function.
    """
    cfg = load_config("demo/pose_cfg_multi.yaml")
    dataset = create_dataset(cfg)
    sm = SpatialModel(cfg)
    sm.load()
    tf.reset_default_graph()
    draw_multi = PersonDraw()
    sess, inputs, outputs = predict.setup_pose_prediction(cfg)
    fps_time = 0
    cap = cv2.VideoCapture("http://192.168.43.31:8081")
    cap_user = cv2.VideoCapture('/dev/video0')
    # NOTE(review): this override discards the network stream, so both
    # captures are the same webcam and every iteration reads two consecutive
    # frames from it. Kept as in the original -- confirm it is intentional.
    cap = cap_user

    try:
        # Every frame pair is processed: the original ``i % 25 == 0`` gate
        # never skipped anything because ``i`` was never incremented.
        while True:
            ret, orig_frame = cap.read()
            ret2, orig_frame_user = cap_user.read()
            # A failed read yields a None frame, on which cv2.resize raises;
            # stop cleanly at end of stream / camera error instead.
            if not ret or not ret2:
                break

            frame = cv2.resize(orig_frame, (0, 0), fx=0.50, fy=0.50)
            user_frame = cv2.resize(orig_frame_user, (0, 0), fx=0.50, fy=0.50)
            co1 = run_predict(frame, sess, outputs, inputs, cfg, dataset, sm,
                              draw_multi)
            print("CO1            ", co1)
            user_co1 = run_predict(user_frame, sess, outputs, inputs, cfg,
                                   dataset, sm, draw_multi)
            print("USER_CO1            ", user_co1)
            print("CO1            ", co1)

            k = None
            try:
                slope_reqd, slope_user = slope_calc(co1, user_co1)
                k, _ = compare_images(slope_reqd, slope_user, 0.75)
            except IndexError:
                # Presumably raised when the two keypoint sets differ in
                # length (a part was not detected); skip scoring this pair.
                print("Except condition")
            vibrate(k)

            frame = cv2.resize(frame, (0, 0), fx=2.0, fy=2.0)
            user_frame = cv2.resize(user_frame, (0, 0), fx=2.0, fy=2.0)
            cv2.putText(user_frame,
                        "FPS: %f" % (1.0 / (time.time() - fps_time)), (10, 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
            cv2.imshow('user_frame', user_frame)
            cv2.imshow('frame', frame)
            fps_time = time.time()
            if cv2.waitKey(10) == ord('q'):
                break
    finally:
        cap.release()
        cap_user.release()
        cv2.destroyAllWindows()
def main(option):
    """Compare the pose in a reference GIF with the pose in a recorded user video.

    Frames from ``msgifs/icon4.gif`` and ``user.mp4`` are read in lockstep,
    downscaled to half size and passed through ``run_predict``. Their limb
    slopes are compared with ``compare_images`` (threshold 0.1), and both
    frames are shown in OpenCV windows.

    The loop ends when ``q`` is pressed or when either source runs out of
    frames. Captures and windows are released even if prediction raises.

    Args:
        option: Not used by this function.
    """
    cfg = load_config("demo/pose_cfg_multi.yaml")
    dataset = create_dataset(cfg)
    sm = SpatialModel(cfg)
    sm.load()
    tf.reset_default_graph()
    draw_multi = PersonDraw()
    sess, inputs, outputs = predict.setup_pose_prediction(cfg)
    fps_time = 0
    cap = cv2.VideoCapture('msgifs/icon4.gif')
    cap_user = cv2.VideoCapture('user.mp4')

    try:
        # Every frame pair is processed: the original ``i % 25 == 0`` gate
        # never skipped anything because ``i`` was never incremented.
        while True:
            ret, orig_frame = cap.read()
            ret2, orig_frame_user = cap_user.read()
            # Both sources are finite files; once one ends, read() returns a
            # None frame and cv2.resize would raise. Stop cleanly instead.
            if not ret or not ret2:
                break

            frame = cv2.resize(orig_frame, (0, 0), fx=0.50, fy=0.50)
            user_frame = cv2.resize(orig_frame_user, (0, 0), fx=0.50, fy=0.50)
            co1 = run_predict(frame, sess, inputs, outputs, cfg, dataset, sm,
                              draw_multi)
            user_co1 = run_predict(user_frame, sess, inputs, outputs, cfg,
                                   dataset, sm, draw_multi)
            try:
                slope_reqd = slope_calc(co1)
                slope_user = slope_calc(user_co1)
                compare_images(slope_reqd, slope_user, 0.1)
            except IndexError:
                # Presumably the keypoint sets did not line up (some parts not
                # visible to the camera); skip the comparison for this pair.
                pass

            cv2.putText(user_frame,
                        "FPS: %f" % (1.0 / (time.time() - fps_time)), (10, 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            cv2.imshow('user_frame', user_frame)
            cv2.imshow('frame', frame)
            fps_time = time.time()
            if cv2.waitKey(10) == ord('q'):
                break
    finally:
        cap.release()
        cap_user.release()
        cv2.destroyAllWindows()
# Example No. 3
# 0
# for object-tracker
import dlib

# import video_pose

####################

# Build the pose-estimation pipeline from the multi-person demo config.
cfg = load_config("demo/pose_cfg_multi.yaml")

dataset = create_dataset(cfg)

# Construct the SpatialModel for this config and load it.
sm = SpatialModel(cfg)
sm.load()

draw_multi = PersonDraw()

# Load and setup CNN part detector
sess, inputs, outputs = predict.setup_pose_prediction(cfg)

##########
## Get the source of video

# Command-line options: input video path, output width, output extension.
parser = ap.ArgumentParser()
parser.add_argument('-f', "--videoFile", help="Path to Video File")
parser.add_argument('-w', "--videoWidth", help="Width of Output Video")
parser.add_argument('-o', "--videoType", help="Extension of Output Video")

# Parsed options as a plain dict keyed by option name ("videoFile", ...).
args = vars(parser.parse_args())

# NOTE(review): this `if` has no body in this copy -- the next line starts a
# new top-level `def`, so the snippet is truncated here and will not parse
# as-is. Restore the missing branch (presumably opening args["videoFile"]).
if args["videoFile"] is not None:
def test_net(visualise, cache_scoremaps, development):
    """Run multi-person pose inference over the configured dataset.

    For every image, the CNN outputs are obtained in one of two ways:
    computed with a TF session, or loaded from ``<coco_id>.mat`` files
    under ``cfg.cached_scoremaps`` when that key is present in the config.
    The outputs are score maps, location refinement and pairwise
    differences. They are then turned into part detections and passed
    through ``eval_graph`` and ``get_person_conf_multicut``.

    Args:
        visualise: If true, plot the drawn poses with matplotlib and wait for
            a button press after each image.
        cache_scoremaps: If true (and score maps are not read from cache),
            save each image's CNN outputs as ``<coco_id>.mat`` under
            ``cfg.scoremap_dir`` and skip further processing of that image.
        development: If true, process at most the first 10 images.

    When ``cfg.use_gt_segm`` is set, the COCO-format results of
    ``pose_predict_with_gt_segm`` are accumulated and written to
    ``predictions_with_segm.json``.
    """
    logging.basicConfig(level=logging.INFO)

    cfg = load_config()
    dataset = create_dataset(cfg)
    dataset.set_shuffle(False)

    sm = SpatialModel(cfg)
    sm.load()

    draw_multi = PersonDraw()

    # With cached score maps the CNN is never run, so no session is created.
    from_cache = "cached_scoremaps" in cfg
    if not from_cache:
        sess, inputs, outputs = setup_pose_prediction(cfg)

    if cache_scoremaps:
        out_dir = cfg.scoremap_dir
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    pairwise_stats = dataset.pairwise_stats
    num_images = dataset.num_images if not development else min(
        10, dataset.num_images)
    coco_results = []

    for k in range(num_images):
        print('processing image {}/{}'.format(k, num_images - 1))

        batch = dataset.next_batch()

        cache_name = "{}.mat".format(batch[Batch.data_item].coco_id)

        if not from_cache:
            outputs_np = sess.run(outputs,
                                  feed_dict={inputs: batch[Batch.inputs]})
            scmap, locref, pairwise_diff = extract_cnn_output(
                outputs_np, cfg, pairwise_stats)

            if cache_scoremaps:
                if visualise:
                    # NOTE(review): visualising in caching mode `continue`s
                    # before savemat, so nothing is cached for this image.
                    img = np.squeeze(batch[Batch.inputs]).astype('uint8')
                    pose = argmax_pose_predict(scmap, locref, cfg.stride)
                    arrows = argmax_arrows_predict(scmap, locref,
                                                   pairwise_diff, cfg.stride)
                    visualize.show_arrows(cfg, img, pose, arrows)
                    visualize.waitforbuttonpress()
                    continue

                out_fn = os.path.join(out_dir, cache_name)
                # Named `mat_contents` rather than `dict` to avoid shadowing
                # the builtin.
                mat_contents = {
                    'scoremaps': scmap.astype('float32'),
                    'locreg_pred': locref.astype('float32'),
                    'pairwise_diff': pairwise_diff.astype('float32')
                }
                scipy.io.savemat(out_fn, mdict=mat_contents)
                continue
        else:
            full_fn = os.path.join(cfg.cached_scoremaps, cache_name)
            mlab = scipy.io.loadmat(full_fn)
            scmap = mlab["scoremaps"]
            locref = mlab["locreg_pred"]
            pairwise_diff = mlab["pairwise_diff"]

        detections = extract_detections(cfg, scmap, locref, pairwise_diff)
        unLab, pos_array, unary_array, pwidx_array, pw_array = eval_graph(
            sm, detections)
        person_conf_multi = get_person_conf_multicut(sm, unLab, unary_array,
                                                     pos_array)

        if visualise:
            img = np.squeeze(batch[Batch.inputs]).astype('uint8')
            visim_multi = img.copy()
            draw_multi.draw(visim_multi, dataset, person_conf_multi)

            plt.imshow(visim_multi)
            plt.show()
            visualize.waitforbuttonpress()

        if cfg.use_gt_segm:
            coco_img_results = pose_predict_with_gt_segm(
                scmap, locref, cfg.stride, batch[Batch.data_item].gt_segm,
                batch[Batch.data_item].coco_id)
            coco_results += coco_img_results
            if len(coco_img_results):
                dataset.visualize_coco(coco_img_results,
                                       batch[Batch.data_item].visibilities)

    if cfg.use_gt_segm:
        with open('predictions_with_segm.json', 'w') as outfile:
            json.dump(coco_results, outfile)

    # `sess` only exists when the CNN was run; closing it unconditionally
    # raised NameError in cached mode after all the work was done.
    if not from_cache:
        sess.close()
# Example No. 5
# 0
from dataset.pose_dataset import data_to_input

from multiperson.detections import extract_detections
from multiperson.predict import SpatialModel, eval_graph, get_person_conf_multicut
from multiperson.visualize import PersonDraw, visualize_detections

import matplotlib.pyplot as plt

cfg = load_config("demo/pose_cfg_multi.yaml")

dataset = create_dataset(cfg)

sm = SpatialModel(cfg)
sm.load()

draw_multi = PersonDraw()

# Load and setup CNN part detector
sess, inputs, outputs = predict.setup_pose_prediction(cfg)

# Read image from file.
# imageio's imread rejects the `mode` keyword (it raises TypeError and points
# to `pilmode`). `pilmode='RGB'` asks the Pillow plugin to decode the image as
# 3-channel RGB, which is what `mode='RGB'` was meant to do.
file_name = "demo/image_multi.png"
image = imageio.imread(file_name, pilmode='RGB')

image_batch = data_to_input(image)

# Compute prediction with the CNN
outputs_np = sess.run(outputs, feed_dict={inputs: image_batch})
scmap, locref, pairwise_diff = predict.extract_cnn_output(
    outputs_np, cfg, dataset.pairwise_stats)
# Example No. 6
# 0
    plt.show()

    return (s, m)


# In[11]:

# Start from a clean default TF graph before building the pose network
# (presumably so this notebook cell can be re-run).
tf.reset_default_graph()
cfg = load_config("demo/pose_cfg_multi.yaml")

dataset = create_dataset(cfg)

sm = SpatialModel(cfg)
sm.load()

draw_multi = PersonDraw()

# Load and setup CNN part detector
sess, inputs, outputs = predict.setup_pose_prediction(cfg)

# Read image from file
file_name = "demo/try.jpeg"
file_name1 = 'demo/try2.jpeg'
# NOTE(review): the second argument `0` looks like cv2.imread's grayscale
# flag, but which `imread` is in scope is not visible here -- confirm.
image = imread(file_name, 0)
image2 = imread(file_name1, 0)
cap = cv2.VideoCapture('demo/seed.mp4')
i = 0
cap1 = cv2.VideoCapture('demo/comp.mp4')
# NOTE(review): as written this loop never terminates: `i` is never
# incremented, nothing breaks out, the frames read are unused, and `cap1` is
# never read. The cell appears to be truncated here.
while True:
    if i % 8 == 0:
        ret, orig_frame = cap.read()
from multiperson.detections import extract_detections
from multiperson.predict import SpatialModel, eval_graph, get_person_conf_multicut
from multiperson.visualize import PersonDraw, visualize_detections

import matplotlib.pyplot as plt


cfg = load_config("demo/pose_cfg_multi.yaml")

dataset = create_dataset(cfg)

sm = SpatialModel(cfg)
sm.load()

draw_multi = PersonDraw()

# Load and setup CNN part detector
sess, inputs, outputs = predict.setup_pose_prediction(cfg)

# Read image from file
file_name = "demo/image_multi.png"
# NOTE(review): `imread` is imported outside this snippet. The `mode='RGB'`
# keyword matches the old scipy.misc.imread signature (removed in SciPy 1.2);
# imageio's imread rejects `mode` -- confirm which implementation is in scope.
image = imread(file_name, mode='RGB')

# Presumably wraps the single image as a one-element network input batch.
image_batch = data_to_input(image)

# Compute prediction with the CNN
outputs_np = sess.run(outputs, feed_dict={inputs: image_batch})
scmap, locref, pairwise_diff = predict.extract_cnn_output(outputs_np, cfg, dataset.pairwise_stats)

# Turn the CNN outputs into part detections.
detections = extract_detections(cfg, scmap, locref, pairwise_diff)
# Example No. 8
# 0
def main():
    """Run multi-person pose estimation on webcam frames and plot each result.

    Each frame from camera 0 is scaled to 30 % and passed through the CNN
    part detector, ``extract_detections``, ``eval_graph`` and
    ``get_person_conf_multicut``. The skeleton drawn by ``PersonDraw`` is
    then shown with matplotlib (``plt.show`` followed by
    ``visualize.waitforbuttonpress``). The loop ends once a frame can no
    longer be read.
    """
    start_time = time.time()
    print("main hai")
    tf.reset_default_graph()
    cfg = load_config("demo/pose_cfg_multi.yaml")
    dataset = create_dataset(cfg)
    sm = SpatialModel(cfg)
    sm.load()
    draw_multi = PersonDraw()
    # Load and setup CNN part detector
    sess, inputs, outputs = predict.setup_pose_prediction(cfg)

    # The listing is unused, but the call still fails early if the "stick"
    # directory (reference frames are read from it below) is missing.
    os.listdir("stick")
    k = 0
    cap = cv2.VideoCapture(0)
    # The original gated each read on `i % 20 == 0` with `i` never changing,
    # so every iteration read a frame; the no-op gate is dropped.
    while cap.isOpened():
        ret, orig_frame = cap.read()
        if not ret:
            break
        frame = cv2.resize(orig_frame, (0, 0), fx=0.30, fy=0.30)

        image_batch = data_to_input(frame)

        # Compute prediction with the CNN
        outputs_np = sess.run(outputs, feed_dict={inputs: image_batch})
        scmap, locref, pairwise_diff = predict.extract_cnn_output(
            outputs_np, cfg, dataset.pairwise_stats)
        detections = extract_detections(cfg, scmap, locref, pairwise_diff)
        unLab, pos_array, unary_array, pwidx_array, pw_array = eval_graph(
            sm, detections)
        person_conf_multi = get_person_conf_multicut(
            sm, unLab, unary_array, pos_array)

        visim_multi = frame.copy()
        co1 = draw_multi.draw(visim_multi, dataset, person_conf_multi)
        plt.imshow(visim_multi)
        plt.show()
        visualize.waitforbuttonpress()

        # NOTE(review): `k` starts at 0 and is only incremented inside this
        # branch, so as written it never runs. Its latent crashes are fixed
        # so it works once `k` is driven as intended.
        if k == 1:
            qwr = np.zeros((1920, 1080, 3), np.uint8)
            for segment in (5, 7, 6, 4, 9, 11, 8, 10):
                cv2.line(qwr, co1[segment][0], co1[segment][1], (255, 0, 0), 3)
            cv2.imshow('r', qwr)
            ref_path = "stick/frame" + str(k) + ".jpg"
            qw1 = cv2.cvtColor(qwr, cv2.COLOR_BGR2GRAY)
            # `ref_path` is a file path: load the image before converting it
            # (cvtColor on the path string itself raised).
            ref_img = cv2.imread(ref_path)
            if ref_img is None:
                raise FileNotFoundError(ref_path)
            qw2 = cv2.cvtColor(ref_img, cv2.COLOR_BGR2GRAY)

            fig = plt.figure("Images")
            images = (("Original", qw1), ("Contrast", qw2))
            # Distinct loop names so the outer frame/`image` state is not
            # clobbered; each image is drawn into its own subplot (the old
            # `plt.imshow(hash(tuple(image)))` raised TypeError).
            for idx, (name, gray) in enumerate(images):
                ax = fig.add_subplot(1, len(images), idx + 1)
                ax.set_title(name)
                ax.imshow(gray, cmap="gray")
            # compare the images
            s, m = compare_images(qw1, qw2, "Image1 vs Image2")
            k += 1

    elapsed = time.time() - start_time
    # NOTE(review): the label says "Mean squared error" but the value is the
    # elapsed wall time divided by 100 -- confirm the intended metric.
    print("Mean squared error : ", elapsed / 100)
    cap.release()
    cv2.destroyAllWindows()
# Example No. 9
# 0
def video2posevideo(video_name):
    """Render detected multi-person poses onto every frame of a video.

    Each frame of ``video_name`` (read with ``read_video``) goes through the
    CNN part detector and the multi-person spatial model. For every person:

    * each limb whose three keypoints were all detected is drawn as a
      point-and-line chain;
    * if the person has at least ``point_min`` detected keypoints, those
      keypoints and an outlined bounding box are drawn.

    Person counts, the frame counter and the elapsed time are written in the
    top-left corner. The annotated frames are saved to
    ``testset/<video_name>_pose.mp4``.

    Args:
        video_name: Name/path of the input video; also used verbatim to build
            the output file name.
    """
    # time.clock() was removed in Python 3.8; perf_counter() is the portable
    # timer for measuring elapsed time.
    time_start = time.perf_counter()

    import numpy as np

    sys.path.append(os.path.dirname(__file__) + "/../")

    # The former `from scipy.misc import imread, imsave` was unused here and
    # raises ImportError on SciPy >= 1.2, where those functions were removed.

    from config import load_config
    from dataset.factory import create as create_dataset
    from nnet import predict
    from util import visualize
    from dataset.pose_dataset import data_to_input

    from multiperson.detections import extract_detections
    from multiperson.predict import SpatialModel, eval_graph, get_person_conf_multicut
    from multiperson.visualize import PersonDraw, visualize_detections

    import matplotlib.pyplot as plt

    from PIL import Image, ImageDraw, ImageFont
    font = ImageFont.truetype("./font/NotoSans-Bold.ttf", 24)

    import random

    cfg = load_config("demo/pose_cfg_multi.yaml")

    dataset = create_dataset(cfg)

    sm = SpatialModel(cfg)
    sm.load()

    draw_multi = PersonDraw()

    # Load and setup CNN part detector
    sess, inputs, outputs = predict.setup_pose_prediction(cfg)

    ################

    video = read_video(video_name)

    # duration (seconds) * fps (frames per second) = number of frames
    video_frame_number = int(video.duration * video.fps)

    pose_frame_list = []

    point_r = 3  # radius of drawn points
    point_min = 10  # a person with >= point_min detected points counts as REAL
    part_min = 3  # a person with >= part_min complete limbs counts as REAL
    point_num = 17  # keypoints per person

    # Keypoint index chains for each limb, drawn in this order.
    limbs = (
        (5, 7, 9),  # left arm
        (6, 8, 10),  # right arm
        (11, 13, 15),  # left leg
        (12, 14, 16),  # right leg
    )

    def ellipse_set(person_conf_multi, people_i, point_i):
        # Bounding box of a circle of radius point_r centred on the keypoint.
        x, y = person_conf_multi[people_i][point_i][0], person_conf_multi[people_i][point_i][1]
        return (x - point_r, y - point_r, x + point_r, y + point_r)

    def line_set(person_conf_multi, people_i, point_i, point_j):
        # Segment between two keypoints of the same person.
        return (person_conf_multi[people_i][point_i][0], person_conf_multi[people_i][point_i][1],
                person_conf_multi[people_i][point_j][0], person_conf_multi[people_i][point_j][1])

    def draw_ellipse_and_line(draw, person_conf_multi, people_i, a, b, c, point_color):
        # Draw the three joints of a limb and the two segments a-b and b-c.
        draw.ellipse(ellipse_set(person_conf_multi, people_i, a), fill=point_color)
        draw.ellipse(ellipse_set(person_conf_multi, people_i, b), fill=point_color)
        draw.ellipse(ellipse_set(person_conf_multi, people_i, c), fill=point_color)
        draw.line(line_set(person_conf_multi, people_i, a, b), fill=point_color, width=5)
        draw.line(line_set(person_conf_multi, people_i, b, c), fill=point_color, width=5)

    for i in range(0, video_frame_number):
        image = video.get_frame(i / video.fps)

        image_batch = data_to_input(image)

        # Compute prediction with the CNN
        outputs_np = sess.run(outputs, feed_dict={inputs: image_batch})
        scmap, locref, pairwise_diff = predict.extract_cnn_output(outputs_np, cfg, dataset.pairwise_stats)

        detections = extract_detections(cfg, scmap, locref, pairwise_diff)
        unLab, pos_array, unary_array, pwidx_array, pw_array = eval_graph(sm, detections)
        person_conf_multi = get_person_conf_multicut(sm, unLab, unary_array, pos_array)

        image_img = Image.fromarray(image)
        draw = ImageDraw.Draw(image_img)

        people_real_num = 0
        people_part_num = 0

        # Each person contributes point_num (x, y) pairs.
        people_num = int(person_conf_multi.size / (point_num * 2))
        print('people_num: ' + str(people_num))

        for people_i in range(0, people_num):
            point_color_r = random.randrange(0, 256)
            point_color_g = random.randrange(0, 256)
            point_color_b = random.randrange(0, 256)
            point_color = (point_color_r, point_color_g, point_color_b, 255)

            person = person_conf_multi[people_i]
            # A keypoint whose coordinates sum to 0, e.g. (0, 0), is treated
            # as not detected.
            point_list = [p for p in range(point_num) if person[p][0] + person[p][1] != 0]
            point_count = len(point_list)
            part_count = 0  # complete limbs found for THIS person

            for a, b, c in limbs:
                if a in point_list and b in point_list and c in point_list:
                    draw_ellipse_and_line(draw, person_conf_multi, people_i, a, b, c, point_color)
                    part_count += 1

            if point_count >= point_min:
                people_real_num += 1
                people_x = []
                people_y = []
                for point_i in point_list:
                    draw.ellipse(ellipse_set(person_conf_multi, people_i, point_i), fill=point_color)
                    people_x.append(person[point_i][0])
                    people_y.append(person[point_i][1])
                # Outline (not fill) the person's bounding box: the original
                # `fill=point_color` painted over the person and the keypoints
                # just drawn, and `outline=5` was an integer colour.
                draw.rectangle([min(people_x), min(people_y), max(people_x), max(people_y)],
                               outline=point_color)

            if part_count >= part_min:
                people_part_num += 1

        draw.text((0, 0), 'People(by point): ' + str(people_real_num) + ' (threshold = ' + str(point_min) + ')', (0,0,0), font=font)
        draw.text((0, 32), 'People(by line): ' + str(people_part_num) + ' (threshold = ' + str(part_min) + ')', (0,0,0), font=font)
        draw.text((0, 64), 'Frame: ' + str(i) + '/' + str(video_frame_number), (0,0,0), font=font)
        draw.text((0, 96), 'Total time required: ' + str(round(time.perf_counter() - time_start, 1)) + 'sec', (0,0,0), font=font)

        print('people_real_num: ' + str(people_real_num))
        print('people_part_num: ' + str(people_part_num))
        print('frame: ' + str(i))

        pose_frame_list.append(np.asarray(image_img))

    video_pose = ImageSequenceClip(pose_frame_list, fps=video.fps)
    video_pose.write_videofile("testset/" + video_name + "_pose.mp4", fps=video.fps)

    print("Time(s): " + str(time.perf_counter() - time_start))
def test_net(visualise, cache_scoremaps, development):
    """Run multi-person pose inference over the configured dataset.

    For every image, the CNN outputs are obtained in one of two ways:
    computed with a TF session, or loaded from ``<coco_id>.mat`` files
    under ``cfg.cached_scoremaps`` when that key is present in the config.
    The outputs are score maps, location refinement and pairwise
    differences. They are then turned into part detections and passed
    through ``eval_graph`` and ``get_person_conf_multicut``.

    Args:
        visualise: If true, plot the drawn poses with matplotlib and wait for
            a button press after each image.
        cache_scoremaps: If true (and score maps are not read from cache),
            save each image's CNN outputs as ``<coco_id>.mat`` under
            ``cfg.scoremap_dir`` and skip further processing of that image.
        development: If true, process at most the first 10 images.

    When ``cfg.use_gt_segm`` is set, the COCO-format results of
    ``pose_predict_with_gt_segm`` are accumulated and written to
    ``predictions_with_segm.json``.
    """
    logging.basicConfig(level=logging.INFO)

    cfg = load_config()
    dataset = create_dataset(cfg)
    dataset.set_shuffle(False)

    sm = SpatialModel(cfg)
    sm.load()

    draw_multi = PersonDraw()

    # With cached score maps the CNN is never run, so no session is created.
    from_cache = "cached_scoremaps" in cfg
    if not from_cache:
        sess, inputs, outputs = setup_pose_prediction(cfg)

    if cache_scoremaps:
        out_dir = cfg.scoremap_dir
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

    pairwise_stats = dataset.pairwise_stats
    num_images = dataset.num_images if not development else min(10, dataset.num_images)
    coco_results = []

    for k in range(num_images):
        print('processing image {}/{}'.format(k, num_images-1))

        batch = dataset.next_batch()

        cache_name = "{}.mat".format(batch[Batch.data_item].coco_id)

        if not from_cache:
            outputs_np = sess.run(outputs, feed_dict={inputs: batch[Batch.inputs]})
            scmap, locref, pairwise_diff = extract_cnn_output(outputs_np, cfg, pairwise_stats)

            if cache_scoremaps:
                if visualise:
                    # NOTE(review): visualising in caching mode `continue`s
                    # before savemat, so nothing is cached for this image.
                    img = np.squeeze(batch[Batch.inputs]).astype('uint8')
                    pose = argmax_pose_predict(scmap, locref, cfg.stride)
                    arrows = argmax_arrows_predict(scmap, locref, pairwise_diff, cfg.stride)
                    visualize.show_arrows(cfg, img, pose, arrows)
                    visualize.waitforbuttonpress()
                    continue

                out_fn = os.path.join(out_dir, cache_name)
                # Named `mat_contents` rather than `dict` to avoid shadowing
                # the builtin.
                mat_contents = {'scoremaps': scmap.astype('float32'),
                                'locreg_pred': locref.astype('float32'),
                                'pairwise_diff': pairwise_diff.astype('float32')}
                scipy.io.savemat(out_fn, mdict=mat_contents)
                continue
        else:
            full_fn = os.path.join(cfg.cached_scoremaps, cache_name)
            mlab = scipy.io.loadmat(full_fn)
            scmap = mlab["scoremaps"]
            locref = mlab["locreg_pred"]
            pairwise_diff = mlab["pairwise_diff"]

        detections = extract_detections(cfg, scmap, locref, pairwise_diff)
        unLab, pos_array, unary_array, pwidx_array, pw_array = eval_graph(sm, detections)
        person_conf_multi = get_person_conf_multicut(sm, unLab, unary_array, pos_array)

        if visualise:
            img = np.squeeze(batch[Batch.inputs]).astype('uint8')
            visim_multi = img.copy()
            draw_multi.draw(visim_multi, dataset, person_conf_multi)

            plt.imshow(visim_multi)
            plt.show()
            visualize.waitforbuttonpress()

        if cfg.use_gt_segm:
            coco_img_results = pose_predict_with_gt_segm(scmap, locref, cfg.stride, batch[Batch.data_item].gt_segm,
                                                         batch[Batch.data_item].coco_id)
            coco_results += coco_img_results
            if len(coco_img_results):
                dataset.visualize_coco(coco_img_results, batch[Batch.data_item].visibilities)

    if cfg.use_gt_segm:
        with open('predictions_with_segm.json', 'w') as outfile:
            json.dump(coco_results, outfile)

    # `sess` only exists when the CNN was run; closing it unconditionally
    # raised NameError in cached mode after all the work was done.
    if not from_cache:
        sess.close()