def __init__(self, n_joints, height, width, n_filter):
    self.n_joints = n_joints
    self.height = height
    self.width = width
    self.n_filter = n_filter
    self.motor_input, self.gt_image, self.predicted_image, self.predicted_error, self.weight_error_loss, self.loss =\
        self.create_network(self.n_joints, self.height, self.width, n_filter)
    self.m_normalizer = Normalizer(low=-1, high=1)
    self.s_normalizer = Normalizer(
        low=0, high=1, min_data=0,
        max_data=1)  # identity mapping here, as pixels are already in [0, 1]
    self.saver = tf.train.Saver()
    self.fig = plt.figure(1, figsize=(14, 8))
Example #2
def processFile(l):
    
    js_file_path = l[0]
    
    if js_file_path in seen:
        return (js_file_path, None, 'Skipped')
    
    pid = int(multiprocessing.current_process().ident)
    
    # Temp files to be created during processing
    temp_files = {'path_tmp': 'tmp_%d.js' % pid,
                  'path_tmp_b': 'tmp_%d.b.js' % pid,
                  'path_tmp_b_n': 'tmp_%d.b.n.js' % pid,
                  'path_tmp_u': 'tmp_%d.u.js' % pid,
                  'path_tmp_u_n': 'tmp_%d.u.n.js' % pid,
                  'path_tmp_b_a': 'tmp_%d.b.a.js' % pid,
                  'path_tmp_u_a': 'tmp_%d.u.a.js' % pid}
    
    try:        
        # Strip comments, replace literals, etc
        try:
            prepro = Preprocessor(os.path.join(corpus_root, js_file_path))
            prepro.write_temp_file(temp_files['path_tmp'])
        except:
            cleanup(temp_files)
            return (js_file_path, None, 'Preprocessor fail')
        
        
        # Pass through beautifier to fix layout:
        # - once through JSNice without renaming
#         jsNiceBeautifier = JSNice(flags=['--no-types', '--no-rename'])
#         
#         (ok, _out, _err) = jsNiceBeautifier.run(temp_files['path_tmp'], 
#                                                 temp_files['path_tmp_b_n'])
#         if not ok:
#             cleanup(temp_files)
#             return (js_file_path, None, 'JSNice Beautifier fail')
        
        
#         # - and another time through uglifyjs pretty print only 
#         clear = Beautifier()
#         ok = clear.run(temp_files['path_tmp_b_n'], 
#                        temp_files['path_tmp_b'])
#         if not ok:
#             cleanup(temp_files)
#             return (js_file_path, None, 'Beautifier fail')
        
        # JSNice is down! Beautify with uglifyjs only:
        clear = Beautifier()
        ok = clear.run(temp_files['path_tmp'], 
                       temp_files['path_tmp_b_n'])
        if not ok:
            cleanup(temp_files)
            return (js_file_path, None, 'Beautifier fail')
        # Normalize
        norm = Normalizer()
        ok = norm.run(os.path.join(os.path.dirname(os.path.realpath(__file__)), 
                                 temp_files['path_tmp_b_n']),
                      False, 
                      temp_files['path_tmp_b'])
        if not ok:
            cleanup(temp_files)
            return (js_file_path, None, 'Normalizer fail')
        
        
        
        # Minify
        ugly = Uglifier()
        ok = ugly.run(temp_files['path_tmp_b'], 
                      temp_files['path_tmp_u_n'])
        if not ok:
            cleanup(temp_files)
            return (js_file_path, None, 'Uglifier fail')
        # Normalize
        norm = Normalizer()
        ok = norm.run(os.path.join(os.path.dirname(os.path.realpath(__file__)), 
                                 temp_files['path_tmp_u_n']),
                      False, 
                      temp_files['path_tmp_u'])
        if not ok:
            cleanup(temp_files)
            return (js_file_path, None, 'Normalizer fail')
        
        
        
        # Num tokens before vs after
        try:
            tok_clear = Lexer(temp_files['path_tmp_b']).tokenList
            tok_ugly = Lexer(temp_files['path_tmp_u']).tokenList
        except:
            cleanup(temp_files)
            return (js_file_path, None, 'Lexer fail')
        
        # For now only work with minified files that have
        # the same number of tokens as the originals
        if len(tok_clear) != len(tok_ugly):
            cleanup(temp_files)
            return (js_file_path, None, 'Num tokens mismatch')
        
        
        # Align minified and clear files, in case the beautifier 
        # did something weird
        try:
            aligner = Aligner()
            # This is already the baseline corpus, no (smart) renaming yet
            aligner.align(temp_files['path_tmp_b'], 
                          temp_files['path_tmp_u'])
        except:
            cleanup(temp_files)
            return (js_file_path, None, 'Aligner fail')
        
        try:
            lex_clear = Lexer(temp_files['path_tmp_b_a'])
            iBuilder_clear = IndexBuilder(lex_clear.tokenList)
            
            lex_ugly = Lexer(temp_files['path_tmp_u_a'])
            iBuilder_ugly = IndexBuilder(lex_ugly.tokenList)
        except:
            cleanup(temp_files)
            return (js_file_path, None, 'IndexBuilder fail')
        
        
        
        # Normalize
        norm = Normalizer()
        ok = norm.run(os.path.join(os.path.dirname(os.path.realpath(__file__)), 
                                 temp_files['path_tmp_b']),
                      True, 
                      temp_files['path_tmp_u_n'])
        if not ok:
            cleanup(temp_files)
            return (js_file_path, None, 'Normalizer fail')
        
        try:
            lex_norm = Lexer(temp_files['path_tmp_u_n'])
            iBuilder_norm = IndexBuilder(lex_norm.tokenList)
        except:
            cleanup(temp_files)
            return (js_file_path, None, 'IndexBuilder fail')
        
        normalized = []
        for line_idx, line in enumerate(iBuilder_norm.tokens):
            normalized.append(' '.join([t for (_tt,t) in line]) + "\n")
        
        
        
        # Compute scoping: name2scope is a dictionary where keys
        # are (name, start_index) tuples and values are scope identifiers. 
        # Note: start_index is a flat (unidimensional) index, 
        # not a (line_chr_idx, col_chr_idx) index.
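        # Illustration (hypothetical values, not from the original code):
        #   {('foo', 15): 'global', ('i', 214): 'scope_3'}
        # where 15 and 214 are flat character offsets into the file.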
        try:
            scopeAnalyst = ScopeAnalyst(os.path.join(
                                 os.path.dirname(os.path.realpath(__file__)), 
                                 temp_files['path_tmp_u_a']))
#             _name2defScope = scopeAnalyst.resolve_scope()
#             _isGlobal = scopeAnalyst.isGlobal
#             _name2useScope = scopeAnalyst.resolve_use_scope()
        except:
            cleanup(temp_files)
            return (js_file_path, None, 'ScopeAnalyst fail')
        
        orig = []
        no_renaming = []
        
        for line_idx, line in enumerate(iBuilder_ugly.tokens):
            orig.append(' '.join([t for (_tt,t) in \
                                  iBuilder_clear.tokens[line_idx]]) + "\n")
            
            no_renaming.append(' '.join([t for (_tt,t) in line]) + "\n")
            
        # Simple renaming: disambiguate overloaded names using scope id
        basic_renaming = renameUsingScopeId(scopeAnalyst, 
                                            iBuilder_ugly)
        
        # More complicated renaming: collect the context around  
        # each name (global variables, API calls, punctuation)
        # and build a hash of the concatenation.
#         hash_renaming = renameUsingHashAllPrec(scopeAnalyst, 
#                                                 iBuilder_ugly,
#                                                 debug=True)
        
        hash_def_one_renaming = renameUsingHashDefLine(scopeAnalyst, 
                                                   iBuilder_ugly, 
                                                   twoLines=False,
                                                   debug=False)

        hash_def_two_renaming = renameUsingHashDefLine(scopeAnalyst, 
                                                    iBuilder_ugly, 
                                                    twoLines=True,
                                                    debug=False)

        cleanup(temp_files)
        return (js_file_path,
                orig, 
                no_renaming, 
                basic_renaming,
                normalized, 
#                 hash_renaming,
                hash_def_one_renaming,
                hash_def_two_renaming)
        
    except Exception as e:
        cleanup(temp_files)
        return (js_file_path, None, str(e))
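
# A minimal driver sketch (assumptions, not from the original file: `seen` and
# `corpus_root` are module-level globals read by processFile, and each input
# row is a list whose first element is a JS path relative to corpus_root):
#
#   import multiprocessing
#
#   rows = [['lib/a.js'], ['src/b.js']]
#   pool = multiprocessing.Pool(processes=4)
#   for result in pool.imap_unordered(processFile, rows):
#       if result[1] is None:
#           print(result[0], result[2])  # path and failure reason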
Example #3
def generate_video(dir_model="model/trained", n_samples=2000, dir_video="temp/video"):
    """
    Generate a video of the estimated body image by moving randomly and smoothly in the motor space.

    Parameters:
        dir_model - model directory
        n_samples - number of samples in the motor space
        dir_video - directory where to save the video
    """

    # check the video directory
    if os.path.exists(dir_video):
        ans = input("> The folder {} already exists; do you want to overwrite its content? [y,n]: ".format(dir_video))
        if ans != "y":
            print("exiting the program")
            return
    if not os.path.exists(dir_video):
        os.makedirs(dir_video)

    # normalize the pixel channels in [0, 1] and subsample the dataset
    s_normalizer = Normalizer(low=0, high=1, min_data=0, max_data=1)  # identity mapping in this case, as the pixel values are already in [0, 1]

    # load the network
    saver, motor_input, net_predicted_image, net_predicted_error = load_network(dir_model)

    # get parameters
    n_joints = motor_input.get_shape()[1].value
    height = net_predicted_image.get_shape()[1].value
    width = net_predicted_image.get_shape()[2].value

    # create a background checkerboard
    checkerboard = create_checkerboard(height, width)

    # create a smooth trajectory in the motor space
    n_anchors = n_samples//40
    anchors = 2 * np.random.rand(n_anchors, n_joints) - 1
    trajectory = np.full((n_samples, n_joints), np.nan)
    for k in range(n_joints):
        tck = interpolate.splrep(np.linspace(0, 1, n_anchors), anchors[:, k])
        trajectory[:, k] = interpolate.splev(np.linspace(0, 1, n_samples), tck)

    # prepare the video writer
    video = cv2.VideoWriter(filename=dir_video + "/video.avi", fourcc=cv2.VideoWriter_fourcc(*'XVID'), fps=24, frameSize=(800, 600))

    # prepare the figure
    fig = plt.figure(figsize=(8, 6))
    ax0 = fig.add_subplot(231, projection="3d")
    ax1 = fig.add_subplot(234, projection="3d")
    ax2 = fig.add_subplot(232)
    ax3 = fig.add_subplot(233)
    ax4 = fig.add_subplot(235)
    ax5 = fig.add_subplot(236)

    with tf.Session() as sess:

        # reload the network's variable values
        saver.restore(sess, tf.train.latest_checkpoint(dir_model + "/"))

        for k in range(n_samples):

            print("\rframe {}".format(k), end="")

            # get the motor input
            curr_motor = trajectory[[k], :]

            # predict image
            predicted_image = sess.run(net_predicted_image, feed_dict={motor_input: curr_motor})[0]
            predicted_image = s_normalizer.reconstruct(predicted_image)  # identity mapping in this case, as the pixel values are already in [0, 1]

            # predict error
            predicted_error = sess.run(net_predicted_error, feed_dict={motor_input: curr_motor})[0]

            # build mask
            predicted_mask = (predicted_error <= 0.056).astype(float)

            # build the masked image
            alpha_channel = np.mean(predicted_mask, axis=2)
            transparent_masked_predicted_image = np.dstack((predicted_image * predicted_mask, alpha_channel))

            # display the motor configuration with a trace
            ax0.cla()
            ax0.set_title("motor space $(m_1, m_2, m_3)$")
            ax0.plot(trajectory[max(0, k - 48):k, 0], trajectory[max(0, k - 48):k, 1], trajectory[max(0, k - 48):k, 2], 'b-')
            ax0.plot(trajectory[k - 1:k, 0], trajectory[k - 1:k, 1], trajectory[k - 1:k, 2], 'ro')
            ax0.set_xlim(-1, 1)
            ax0.set_ylim(-1, 1)
            ax0.set_zlim(-1, 1)
            ax0.set_xticklabels([])
            ax0.set_yticklabels([])
            ax0.set_zticklabels([])
            #
            ax1.cla()
            ax1.set_title("motor space $(m_2, m_3, m_4)$")
            ax1.plot(trajectory[max(0, k - 48):k, 1], trajectory[max(0, k - 48):k, 2], trajectory[max(0, k - 48):k, 3], 'b-')
            ax1.plot(trajectory[k - 1:k, 1], trajectory[k - 1:k, 2], trajectory[k - 1:k, 3], 'ro')
            ax1.set_xlim(-1, 1)
            ax1.set_ylim(-1, 1)
            ax1.set_zlim(-1, 1)
            ax1.set_xticklabels([])
            ax1.set_yticklabels([])
            ax1.set_zticklabels([])

            # display the predicted image
            ax2.cla()
            ax2.set_title("predicted image")
            ax2.imshow(predicted_image)
            ax2.axis("off")
            #
            ax3.cla()
            ax3.set_title("predicted error")
            ax3.imshow(predicted_error)
            ax3.axis("off")
            #
            ax4.cla()
            ax4.set_title("predicted mask")
            ax4.imshow(predicted_mask)
            ax4.axis("off")
            #
            ax5.cla()
            ax5.set_title("masked prediction")
            ax5.imshow(checkerboard)
            ax5.imshow(transparent_masked_predicted_image)
            ax5.axis("off")

            plt.show(block=False)
            fig.savefig(dir_video + "/img.png")
            plt.pause(0.001)

            # write frame
            image = cv2.imread(dir_video + "/img.png")
            video.write(image)

    # clean up
    cv2.destroyAllWindows()
    video.release()
    os.remove(dir_video + "/img.png")
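
# Example call (a sketch; assumes a checkpoint trained by the class below is
# saved under "model/trained" and that OpenCV was built with XVID support):
#
#   generate_video(dir_model="model/trained", n_samples=2000, dir_video="temp/video")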
Example #4
def explore_joint_space(dir_model="model/trained", motor_input_ref=None):
    """
    Regularly sample each dimension of the motor space and display the generated body image.

    Parameters:
        dir_model - model directory
        motor_input_ref - reference motor input from which to explore the motor space
    """

    # load the network
    saver, motor_input, net_predicted_image, net_predicted_error = load_network(dir_model)

    # get parameters
    n_joints = motor_input.get_shape()[1].value

    # generate the reference motor input if necessary
    if motor_input_ref is None:
        motor_input_ref = np.zeros((1, n_joints))

    # create the sensory normalizer
    s_normalizer = Normalizer(low=0, high=1, min_data=0, max_data=1)  # identity mapping in this case, as the pixel values are already in [0, 1]

    # display the reconstructions
    with tf.Session() as sess:

        # reload the network's variable values
        saver.restore(sess, tf.train.latest_checkpoint(dir_model + "/"))

        # iterate over the motor dimensions
        for joint in range(n_joints):

            # create a figure
            fig = plt.figure(figsize=(12, 6))

            for index, val in enumerate(np.linspace(-1, 1, 6)):

                # variation to add to the reference motor input
                delta = [[val if i == joint else 0. for i in range(n_joints)]]

                # predict image
                predicted_image = sess.run(net_predicted_image, feed_dict={motor_input: motor_input_ref + delta})[0]
                predicted_image = s_normalizer.reconstruct(predicted_image)  # identity mapping in this case, as the pixel values are already in [0, 1]

                # predict error
                predicted_error = sess.run(net_predicted_error, feed_dict={motor_input: motor_input_ref + delta})[0]

                # display
                ax1 = fig.add_subplot(2, 6, 1 + index)
                ax2 = fig.add_subplot(2, 6, 7 + index)
                #
                fig.suptitle('joint {}'.format(joint), fontsize=12)
                #
                ax1.set_title("predicted image")
                ax1.imshow(predicted_image)
                ax1.axis("off")
                #
                ax2.set_title("predicted error")
                ax2.imshow(predicted_error)
                ax2.axis("off")
            #
            # fig.savefig(".temp/exploration/joint_{}.svg".format(joint))

    plt.show(block=False)
    plt.pause(0.001)
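
# Example call (a sketch; assumes a trained checkpoint under "model/trained";
# the reference input shape assumes a 4-joint model):
#
#   explore_joint_space(dir_model="model/trained",
#                       motor_input_ref=np.array([[0.5, -0.2, 0.0, 0.1]]))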
Example #5
def fit_gmm(dir_green_dataset="dataset/generated/green", dir_model="model/trained", indexes=100):
    """
    Fit a 2-Gaussian Mixture Model to the distribution of predicted prediction errors over the whole dataset
    to distinguish the pixels belonging to the body image from the ones belonging to the background.

    Parameters:
        dir_green_dataset - green-background dataset directory
        dir_model - model directory
        indexes - indexes (if list) or number of random indexes (if int) of samples to use
    """

    # load the dataset
    m, _, n_samples, _, _, _, _ = load_data(dir_green_dataset)

    # draw indexes if necessary
    if isinstance(indexes, int):
        indexes = np.random.choice(n_samples, indexes)

    # normalize the motor_input configuration in [-1, 1] and subsample the dataset
    m_normalizer = Normalizer(low=-1, high=1)
    m = m_normalizer.fit_transform(m)
    m = m[indexes, :]

    # load the network
    saver, motor_input, _, predicted_error = load_network(dir_model)

    # initialize list
    all_pred_errors = []

    # stack all the predicted prediction errors over the selected set of motor samples
    with tf.Session() as sess:

        # reload the network's variable values
        saver.restore(sess, tf.train.latest_checkpoint(dir_model + "/"))

        for i, ind in enumerate(indexes):

            # predict error
            curr_error = sess.run(predicted_error, feed_dict={motor_input: m[[i], :]})
            curr_error = curr_error[0]

            # append errors
            all_pred_errors = all_pred_errors + list(curr_error.flatten())

    # fit a 2-GMM model
    all_pred_errors = np.array(all_pred_errors).reshape(-1, 1)
    gmm_model = mixture.GaussianMixture(n_components=2, n_init=5)
    gmm_model.fit(all_pred_errors)

    # find the intersection of the two gaussians
    x = np.linspace(-0.05, 0.3, 1000).reshape(-1, 1)
    lp = gmm_model.score_samples(x)  # log probability
    p = gmm_model.predict_proba(x)  # posterior probability of each component
    diff = np.abs(p[:, 0] - p[:, 1])
    cross_index = np.argmin(diff)
    threshold = x[cross_index, 0]
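    # note: argmin of |p0 - p1| locates where the posterior responsibilities
    # of the two components cross 0.5, i.e. the decision boundary between the
    # "body" and "background" error populations; this is the value used as
    # the mask threshold (the hard-coded 0.056 elsewhere in this file).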

    print("Estimated error threshold: {:.3f}".format(threshold))

    # display the histogram and the optimized Gaussians
    fig = plt.figure()
    ax = fig.add_subplot(111)
    #
    ax.hist(all_pred_errors[:, 0], bins=100, density=True, color="blue", rwidth=0.8, label="errors")
    ax.plot(x, np.exp(lp), 'r-', label="GMM")
    ax.legend(loc="upper left")
    #
    ax2 = ax.twinx()
    ax2.plot(x, p[:, 0], 'c--', label="Proba comp 1")
    ax2.plot(x, p[:, 1], 'g--', label="Proba comp 2")
    ax2.set_ylim([0, 1.2])
    ax2.legend(loc="upper right")
    #
    #fig.savefig(".temp/fitted_GMM/gmm.svg")
    #
    plt.show(block=False)
    plt.pause(0.001)

    return threshold
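
# Example call (a sketch; assumes a trained checkpoint and the green-background
# dataset above). The returned threshold plays the role of the hard-coded 0.056
# used by the masking code in this file:
#
#   threshold = fit_gmm(dir_green_dataset="dataset/generated/green",
#                       dir_model="model/trained", indexes=100)
#   predicted_mask = (predicted_error <= threshold).astype(float)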
Example #6
def reconstruct_data(dir_model="model/trained", dir_dataset="dataset/generated/combined", indexes=11):
    """
    Test a network by reconstructing samples from the dataset.

    Parameters:
        dir_model - model directory
        dir_dataset - dataset directory
        indexes - indexes (if list) or number of random indexes (if int) of samples to reconstruct
    """

    # load the dataset
    m, s, n_samples, height, width, n_channels, n_joints = load_data(dir_dataset)

    # draw indexes if necessary
    if isinstance(indexes, int):
        indexes = np.random.choice(n_samples, indexes)

    # normalize the motor_input configuration in [-1, 1] and subsample the dataset
    m_normalizer = Normalizer(low=-1, high=1)
    m = m_normalizer.fit_transform(m)
    m = m[indexes, :]

    # normalize the pixel channels in [0, 1] and subsample the dataset
    s_normalizer = Normalizer(low=0, high=1, min_data=0, max_data=1)  # identity mapping in this case, as the pixel values are already in [0, 1]
    s = s_normalizer.transform(s)
    s = s[indexes, :]

    # load the network
    saver, motor_input, net_predicted_image, net_predicted_error = load_network(dir_model)

    # create a background checkerboard
    checkerboard = create_checkerboard(height, width)

    # create a figure
    fig = plt.figure(figsize=(18, 7))
    fig.suptitle('samples {}'.format(indexes), fontsize=12)

    # display the reconstructions
    with tf.Session() as sess:

        # reload the network's variable values
        saver.restore(sess, tf.train.latest_checkpoint(dir_model + "/"))

        for i, ind in enumerate(indexes):

            # ground truth image
            gt_green_image = s[i, :, :, :]

            # predict image
            predicted_image = sess.run(net_predicted_image, feed_dict={motor_input: m[[i], :]})[0]
            predicted_image = s_normalizer.reconstruct(predicted_image)  # identity mapping in this case, as the pixel values are already in [0, 1]

            # predict error
            predicted_error = sess.run(net_predicted_error, feed_dict={motor_input: m[[i], :]})[0]

            # build mask
            predicted_mask = (predicted_error <= 0.056).astype(float)

            # build the masked image
            alpha_channel = np.mean(predicted_mask, axis=2)
            transparent_masked_predicted_image = np.dstack((predicted_image * predicted_mask, alpha_channel))

            # display
            ax1 = fig.add_subplot(5, len(indexes), 0*len(indexes) + 1 + i)
            ax2 = fig.add_subplot(5, len(indexes), 1*len(indexes) + 1 + i)
            ax3 = fig.add_subplot(5, len(indexes), 2*len(indexes) + 1 + i)
            ax4 = fig.add_subplot(5, len(indexes), 3*len(indexes) + 1 + i)
            ax5 = fig.add_subplot(5, len(indexes), 4*len(indexes) + 1 + i)
            #
            ax1.set_title("ground-truth image")
            ax1.imshow(gt_green_image)
            ax1.axis("off")
            #
            ax2.set_title("predicted image")
            ax2.imshow(predicted_image)
            ax2.axis("off")
            #
            ax3.set_title("predicted error")
            ax3.imshow(predicted_error)
            ax3.axis("off")
            #
            ax4.set_title("mask")
            ax4.imshow(predicted_mask)
            ax4.axis("off")
            #
            ax5.set_title('masked predicted image')
            ax5.imshow(checkerboard)
            ax5.imshow(transparent_masked_predicted_image)
            ax5.axis("off")
        #
        # fig.savefig(".temp/reconstruction/reconstructions.svg")

    plt.show(block=False)
    plt.pause(0.001)
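
# Example call (a sketch; assumes a trained checkpoint and the combined
# dataset above; indexes can also be an explicit list of sample indexes):
#
#   reconstruct_data(dir_model="model/trained",
#                    dir_dataset="dataset/generated/combined", indexes=11)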
Example #7
def evaluate_body_image(dir_model="model/trained", dir_green_dataset="dataset/generated/green", indexes=6):
    """
    Test the body image mask generated by a network by comparing it to the ground-truth green-background dataset.

    Parameters:
        dir_model - model directory
        dir_green_dataset - green-background dataset directory
        indexes - indexes (if list) or number of random indexes (if int) of samples to display
    """

    # load the dataset
    m, s, n_samples, height, width, n_channels, n_joints = load_data(dir_green_dataset)

    # draw indexes if necessary
    if isinstance(indexes, int):
        indexes = np.random.choice(n_samples, indexes)

    # normalize the motor_input configuration in [-1, 1]
    m_normalizer = Normalizer(low=-1, high=1)
    m = m_normalizer.fit_transform(m)

    # normalize the pixel channels in [0, 1]
    s_normalizer = Normalizer(low=0, high=1, min_data=0, max_data=1)  # identity mapping in this case, as the pixel values are already in [0, 1]
    s = s_normalizer.transform(s)

    # load the network
    saver, motor_input, net_predicted_image, net_predicted_error = load_network(dir_model)

    # create a background checkerboard
    checkerboard = create_checkerboard(height, width)

    # track all matches over the dataset
    all_iou_body = []
    all_appearance_match = []

    # create figure
    fig = plt.figure(figsize=(9, 10))
    fig.suptitle('samples {}'.format(indexes), fontsize=12)

    with tf.Session() as sess:

        # reload the network's variable values
        saver.restore(sess, tf.train.latest_checkpoint(dir_model + "/"))

        # track the number of displayed indexes
        i = 0

        # compute the mask and appearance matches over the whole dataset
        for ind in range(n_samples):

            # ground-truth image
            gt_green_image = s[ind, :, :, :]  # image with green background - [height, width, 3] in [0, 1]

            # ground-truth body mask
            gt_body_mask = ((gt_green_image[:, :, 0] == 0) & (abs(gt_green_image[:, :, 1] - 141/255) <= 1e-3) & (gt_green_image[:, :, 2] == 0)).astype(float)
            gt_body_mask = 1 - np.repeat(gt_body_mask[:, :, np.newaxis], 3, axis=2)  # ground-truth body mask - [height, width, 3] in (0., 1.)

            # predicted image - [height, width, 3] in [0, 1+]
            predicted_image = sess.run(net_predicted_image, feed_dict={motor_input: m[[ind], :]})[0]
            predicted_image = s_normalizer.reconstruct(predicted_image)  # identity mapping in this case, as the pixel values are already in [0, 1]

            # predicted error - [height, width, 3] in [0, 1+]
            predicted_error = sess.run(net_predicted_error, feed_dict={motor_input: m[[ind], :]})[0]

            # predicted body mask
            predicted_body_mask = (predicted_error <= 0.056).astype(float)  # [height, width, 3] in (0., 1.)

            # evaluation of the predicted mask: Intersection over Union
            intersection = np.logical_and(gt_body_mask, predicted_body_mask)
            union = np.logical_or(gt_body_mask, predicted_body_mask)
            iou_body_mask = np.sum(intersection) / np.sum(union) if np.sum(union) > 0 else 1

            # error in the predicted image
            error_image = gt_green_image - predicted_image  # [height, width, 3] in [0., 1.+]

            # evaluation of the body appearance: mean error under the intersection of masks
            masked_image_error = error_image * intersection
            appearance_match = 1 - np.sum(np.abs(masked_image_error)) / np.sum(intersection) if np.sum(intersection) > 0 else 1

            # creation of the mask images with transparency for display
            alpha_channel = np.mean(intersection, axis=2)
            transparent_masked_gt_image = np.dstack((gt_green_image * intersection, alpha_channel))
            transparent_masked_predicted_image = np.dstack((predicted_image * intersection, alpha_channel))

            # store the matches and scores
            all_iou_body.append(iou_body_mask)
            all_appearance_match.append(appearance_match)

            # display the matches for the selected indexes
            if ind in indexes:

                # display
                ax1 = fig.add_subplot(6, len(indexes), 0*len(indexes) + 1 + i)
                ax2 = fig.add_subplot(6, len(indexes), 1*len(indexes) + 1 + i)
                ax3 = fig.add_subplot(6, len(indexes), 2*len(indexes) + 1 + i)
                ax4 = fig.add_subplot(6, len(indexes), 3*len(indexes) + 1 + i)
                ax5 = fig.add_subplot(6, len(indexes), 4*len(indexes) + 1 + i)
                ax6 = fig.add_subplot(6, len(indexes), 5*len(indexes) + 1 + i)
                #
                i = i + 1
                #
                ax1.set_title("ground-truth body mask")
                ax1.imshow(np.where(gt_body_mask == 1., 1., gt_green_image))
                ax1.axis("off")
                #
                ax2.set_title("predicted mask")
                ax2.imshow(predicted_body_mask)
                ax2.axis("off")
                #
                ax3.set_title("mask match (IoU): {:.2f}%".format(100 * iou_body_mask), fontsize=11)
                ax3.imshow((gt_body_mask - predicted_body_mask) / 2 + 0.5)
                ax3.axis("off")
                #
                ax4.set_title("masked ground-truth")
                ax4.imshow(checkerboard)
                ax4.imshow(transparent_masked_gt_image)
                ax4.axis("off")
                #
                ax5.set_title("masked prediction")
                ax5.imshow(checkerboard)
                ax5.imshow(transparent_masked_predicted_image)
                ax5.axis("off")
                #
                ax6.set_title("appearance match: {:.2f}%".format(100 * appearance_match), fontsize=11)
                ax6.imshow(checkerboard)
                ax6.imshow(masked_image_error / 2 + 0.5)
                ax6.axis("off")
                #
                #fig.savefig(".temp/mask_and_appearance_match/evaluation.svg".format(ind))

    # print the stats
    print("mask match = {mean} +/- {std}".format(mean=np.mean(all_iou_body), std=np.std(all_iou_body)))
    print("appearance match = {mean} +/- {std}".format(mean=np.mean(all_appearance_match), std=np.std(all_appearance_match)))

    plt.show(block=False)
    plt.pause(0.001)
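
# Example call (a sketch; assumes a trained checkpoint and the green-background
# dataset above):
#
#   evaluate_body_image(dir_model="model/trained",
#                       dir_green_dataset="dataset/generated/green", indexes=6)
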
class SensoriMotorPredictionNetwork:
    def __init__(self, n_joints, height, width, n_filter):
        self.n_joints = n_joints
        self.height = height
        self.width = width
        self.n_filter = n_filter
        self.motor_input, self.gt_image, self.predicted_image, self.predicted_error, self.weight_error_loss, self.loss =\
            self.create_network(self.n_joints, self.height, self.width, n_filter)
        self.m_normalizer = Normalizer(low=-1, high=1)
        self.s_normalizer = Normalizer(
            low=0, high=1, min_data=0,
            max_data=1)  # identity mapping here, as pixels are already in [0, 1]
        self.saver = tf.train.Saver()
        self.fig = plt.figure(1, figsize=(14, 8))

    @staticmethod
    def create_network(n_joints, h, w, n_filter=32):
        """
            Create the network for sensorimotor prediction.
            Given an input motor configuration, the network outputs a predictive image and predicted prediction error.

            Parameters:
                n_joints - dimension of the motor states
                h - height of the output image
                w - width of the output image
                n_filter - maximal number of convolution filters
            """
        # todo: test padding="same" for the final convolutional layers
        # todo: test with batch normalization

        # reset the default graph
        tf.reset_default_graph()

        # create placeholders
        motor_input = tf.placeholder(dtype=tf.float32,
                                     shape=[None, n_joints],
                                     name="motor_input")
        gt_image = tf.placeholder(dtype=tf.float32,
                                  shape=[None, h, w, 3],
                                  name="gt_image")
        weight_error_loss = tf.placeholder(dtype=tf.float32,
                                           shape=[],
                                           name="weight_error_loss")

        # dense mapping to larger layers
        with tf.name_scope("dense_expand") as scope:
            out = tf.layers.dense(inputs=motor_input,
                                  units=8 * 8 * 3,
                                  activation=tf.nn.selu,
                                  name="layer1")
            out = tf.layers.dense(inputs=out,
                                  units=round(h / 5) * round(w / 5) * 3,
                                  activation=tf.nn.selu,
                                  name="layer2")

        # reshaping
        out = tf.reshape(out,
                         shape=[-1, round(h / 5),
                                round(w / 5), 3],
                         name="reshaping")

        # branch 1: image - deconvolution is done by upsampling + convolution - upsampling with +2 to compensate for the valid padding
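        # size arithmetic for one stage (e.g. h = w = 100): resize to
        # round(h / 4) + 2 = 27, and a single 3x3 'valid' convolution trims a
        # 1-pixel border back to 25 = h / 4; the last resize targets h + 8 so
        # that the four remaining 3x3 'valid' convolutions (4 x 2 pixels)
        # return the output to exactly h x w.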
        with tf.variable_scope("image_branch") as scope:
            img = tf.image.resize_images(
                out,
                size=(round(h / 4) + 2, round(w / 4) + 2),
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                align_corners=True)
            img = tf.layers.conv2d(inputs=img,
                                   filters=n_filter,
                                   kernel_size=(3, 3),
                                   padding='valid',
                                   activation=tf.nn.selu,
                                   name="deconv1")
            #
            img = tf.image.resize_images(
                img,
                size=(round(h / 2) + 2, round(w / 2) + 2),
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                align_corners=True)
            img = tf.layers.conv2d(inputs=img,
                                   filters=n_filter,
                                   kernel_size=(3, 3),
                                   padding='valid',
                                   activation=tf.nn.selu,
                                   name="deconv2")
            #
            img = tf.image.resize_images(
                img,
                size=(h + 8, w + 8),
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                align_corners=True)
            img = tf.layers.conv2d(inputs=img,
                                   filters=n_filter,
                                   kernel_size=(3, 3),
                                   padding='valid',
                                   activation=tf.nn.selu,
                                   name="deconv3")
            #
            # convolutions + reducing the number of filters to 3 channels
            img = tf.layers.conv2d(inputs=img,
                                   filters=n_filter // 2,
                                   kernel_size=(3, 3),
                                   padding="valid",
                                   activation=tf.nn.selu,
                                   name="conv1")
            img = tf.layers.conv2d(inputs=img,
                                   filters=n_filter // 4,
                                   kernel_size=(3, 3),
                                   padding="valid",
                                   activation=tf.nn.selu,
                                   name="conv2")
            img = tf.layers.conv2d(inputs=img,
                                   filters=3,
                                   kernel_size=(3, 3),
                                   padding="valid",
                                   activation=tf.nn.relu,
                                   name="predicted_image")

        # branch 2 - prediction error
        with tf.variable_scope("error_branch") as scope:
            err = tf.image.resize_images(
                out,
                size=(round(h / 4) + 2, round(w / 4) + 2),
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                align_corners=True)
            err = tf.layers.conv2d(inputs=err,
                                   filters=n_filter,
                                   kernel_size=(3, 3),
                                   padding='valid',
                                   activation=tf.nn.selu,
                                   name="deconv1")
            #
            err = tf.image.resize_images(
                err,
                size=(round(h / 2) + 2, round(w / 2) + 2),
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                align_corners=True)
            err = tf.layers.conv2d(inputs=err,
                                   filters=n_filter,
                                   kernel_size=(3, 3),
                                   padding='valid',
                                   activation=tf.nn.selu,
                                   name="deconv2")
            #
            err = tf.image.resize_images(
                err,
                size=(h + 8, w + 8),
                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
                align_corners=True)
            err = tf.layers.conv2d(inputs=err,
                                   filters=n_filter,
                                   kernel_size=(3, 3),
                                   padding='valid',
                                   activation=tf.nn.selu,
                                   name="deconv3")
            #
            # convolutions + reducing the number of filters to 3 channels
            err = tf.layers.conv2d(inputs=err,
                                   filters=n_filter // 2,
                                   kernel_size=(3, 3),
                                   padding="valid",
                                   activation=tf.nn.selu,
                                   name="conv1")
            err = tf.layers.conv2d(inputs=err,
                                   filters=n_filter // 4,
                                   kernel_size=(3, 3),
                                   padding="valid",
                                   activation=tf.nn.selu,
                                   name="conv2")
            err = tf.layers.conv2d(inputs=err,
                                   filters=3,
                                   kernel_size=(3, 3),
                                   padding="valid",
                                   activation=tf.nn.relu,
                                   name="predicted_error")

        # define the loss
        with tf.name_scope("losses_computation") as scope:
            errors_image = tf.abs(tf.subtract(img, gt_image),
                                  name="errors_images")
            loss_reconstruction = tf.reduce_mean(errors_image,
                                                 name="loss_reconstruction")
            errors_mask = tf.abs(tf.subtract(err, errors_image),
                                 name="errors_mask")
            loss_mask = tf.reduce_mean(errors_mask, name="loss_error")
            loss_mask = tf.multiply(weight_error_loss,
                                    loss_mask,
                                    name="weighted_loss_error")
            loss = tf.add(loss_reconstruction, loss_mask, name="loss")
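            # in equation form, with w = weight_error_loss:
            #   loss = mean(|img - gt|) + w * mean(| err - |img - gt| |)
            # i.e. the error branch is trained to predict, pixel-wise, the
            # reconstruction error made by the image branch.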

        return motor_input, gt_image, img, err, weight_error_loss, loss

    def save_network(self, dir_model="model/trained"):
        self.saver.save(
            tf.get_default_session(), dir_model + "/network.ckpt"
        )  # add global_step=global_step to avoid overwriting the previous model

    def train(self,
              m,
              s,
              dir_model="model/trained",
              n_epochs=int(5e4),
              batch_size=100):
        """
            Train the network.
            The error-prediction component of the loss is weighted by a factor that increases from 0 to 1 during the first half of training.

            Parameters:
                m - motor data
                s - sensor data (images)
                dir_model - directory where to save the model
                n_epochs - number of mini-batch iterations
                batch_size - mini-batch size
            """

        # check the model directory
        if os.path.exists(dir_model):
            ans = input(
                "> The folder {} already exists; do you want to overwrite its content? [y,n]: "
                .format(dir_model))
            if ans != "y":
                print("exiting the program")
                return

        # create directories if necessary
        if not os.path.exists(dir_model):
            os.makedirs(dir_model)
        dir_progress = dir_model + "/progress"
        if not os.path.exists(dir_progress):
            os.makedirs(dir_progress)

        # get the number of samples
        n_samples = s.shape[0]

        # normalize the motor_input configuration in [-1, 1]
        m = self.m_normalizer.fit_transform(m)

        # normalize the pixel channels in [0, 1] (doesn't change anything in this case, as plt.imread already outputs values in [0, 1])
        s = self.s_normalizer.transform(s)

        # define the learning rate schedule
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.polynomial_decay(1e-3,
                                                  global_step,
                                                  n_epochs,
                                                  1e-5,
                                                  power=1)
        # define the optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        minimize_op = optimizer.minimize(self.loss, global_step=global_step)

        # define the weighting of the error_loss - ramp up from 0 to 1 during the first half of training
        weight_err = tf.train.polynomial_decay(0.,
                                               global_step,
                                               n_epochs // 2,
                                               1.,
                                               power=1)
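        # note: polynomial_decay with an end value (1.0) above the initial
        # value (0.0) yields a linear ramp from 0 to 1 over the first
        # n_epochs // 2 steps, after which the weight stays at 1.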

        # train the network
        print("training the network...")
        with tf.Session() as sess:

            # initialize the network
            sess.run(tf.global_variables_initializer())

            for epoch in range(n_epochs):

                # draw batch indexes
                indexes = np.random.choice(n_samples, batch_size, replace=True)

                # minimize the loss
                curr_weight = sess.run(weight_err)
                curr_loss, _, curr_lr = sess.run(
                    [self.loss, minimize_op, learning_rate],
                    feed_dict={
                        self.motor_input: m[indexes, :],
                        self.gt_image: s[indexes, :, :],
                        self.weight_error_loss: curr_weight
                    })

                if (epoch % max(1, np.round(n_epochs / 100))
                        == 0) or (epoch == n_epochs - 1):

                    print(
                        "epoch: {} ({:3.0f}%), learning rate: {:.4e}, error loss weight: {:.4e}, loss: {:.4e}"
                        .format(epoch, epoch / n_epochs * 100, curr_lr,
                                curr_weight, curr_loss))

                    # visualize one output
                    curr_image, curr_error = sess.run(
                        [self.predicted_image, self.predicted_error],
                        feed_dict={self.motor_input: m[[indexes[0]], :]})
                    curr_image = self.s_normalizer.reconstruct(curr_image)
                    binary_mask = (curr_error[0] < 0.056).astype(
                        float
                    )  # the threshold value could be estimated on the fly with a GMM
                    self.display_figure(s[indexes[0]], curr_image[0],
                                        curr_error[0], binary_mask)

                    # save the visualization
                    self.save_figure(dir_progress, epoch)

                    # save the network
                    self.save_network(dir_model)

        print("training finished.")

    def display_figure(self, gt_image, pred_image, pred_error, mask):
        """
            Display the output of the network for one input sample.

            Parameters:
                gt_image - ground truth image
                pred_image - predicted image
                pred_error - predicted prediction error
                mask - estimated mask
            """

        if not plt.fignum_exists(1):
            self.fig = plt.figure(1, figsize=(14, 8))

        # clean the figure
        plt.clf()
        ax1 = self.fig.add_subplot(231)
        ax2 = self.fig.add_subplot(232)
        ax3 = self.fig.add_subplot(234)
        ax4 = self.fig.add_subplot(235)
        ax5 = self.fig.add_subplot(133)

        checkerboard = create_checkerboard(pred_image.shape[0],
                                           pred_image.shape[1])

        ax1.cla()
        ax1.set_title("ground-truth image")
        ax1.imshow(gt_image)
        ax1.axis("off")

        ax2.cla()
        ax2.set_title("predicted image")
        ax2.imshow(pred_image)
        ax2.axis("off")

        ax3.cla()
        ax3.set_title("predicted prediction error")
        ax3.imshow(pred_error)
        ax3.axis("off")

        ax4.cla()
        ax4.set_title('mask')
        ax4.imshow(mask)
        ax4.axis("off")

        ax5.cla()
        alpha_channel = np.mean(mask, axis=2)
        transparent_masked_predicted_image = np.dstack(
            (pred_image * mask, alpha_channel))
        ax5.set_title('masked prediction')
        ax5.imshow(checkerboard)
        ax5.imshow(transparent_masked_predicted_image)
        ax5.axis("off")

        plt.show(block=False)
        plt.pause(1e-8)

    def save_figure(self, path, epoch):
        self.fig.savefig(path + "/epoch_{:06d}.png".format(epoch))
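
# A minimal usage sketch (assumes the load_data helper used by the functions
# above is available in this module):
if __name__ == "__main__":
    m, s, n_samples, height, width, n_channels, n_joints = \
        load_data("dataset/generated/combined")
    net = SensoriMotorPredictionNetwork(n_joints, height, width, n_filter=32)
    net.train(m, s, dir_model="model/trained", n_epochs=int(5e4), batch_size=100)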