Example #1
def train_one_epoch(sess, ops, fetchworker, train_writer):
    loss_sum = []
    fetch_time = 0
    for batch_idx in range(fetchworker.num_batches):
        start = time.time()
        batch_input_data, batch_data_gt, radius = fetchworker.fetch()
        end = time.time()
        fetch_time += end - start
        feed_dict = {
            ops['pointclouds_pl']: batch_input_data,
            ops['pointclouds_gt']: batch_data_gt[:, :, 0:3],
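            # note: the xyz columns are fed to the normal placeholder below,
            # so ground-truth normals are effectively unused in this setup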
            ops['pointclouds_gt_normal']: batch_data_gt[:, :, 0:3],
            ops['pointclouds_radius']: radius
        }
        summary, step, _, pred_val, gen_loss_emd = sess.run(
            [
                ops['pretrain_merged'], ops['step'], ops['pre_gen_train'],
                ops['pred'], ops['gen_loss_emd']
            ],
            feed_dict=feed_dict)
        train_writer.add_summary(summary, step)
        loss_sum.append(gen_loss_emd)

        if step % 30 == 0:
            pointclouds_image_input = pc_util.point_cloud_three_views(
                batch_input_data[0, :, 0:3])
            pointclouds_image_input = np.expand_dims(
                np.expand_dims(pointclouds_image_input, axis=-1), axis=0)
            pointclouds_image_pred = pc_util.point_cloud_three_views(
                pred_val[0, :, :])
            pointclouds_image_pred = np.expand_dims(
                np.expand_dims(pointclouds_image_pred, axis=-1), axis=0)
            pointclouds_image_gt = pc_util.point_cloud_three_views(
                batch_data_gt[0, :, 0:3])
            pointclouds_image_gt = np.expand_dims(
                np.expand_dims(pointclouds_image_gt, axis=-1), axis=0)
            feed_dict = {
                ops['pointclouds_image_input']: pointclouds_image_input,
                ops['pointclouds_image_pred']: pointclouds_image_pred,
                ops['pointclouds_image_gt']: pointclouds_image_gt,
            }
            summary = sess.run(ops['image_merged'], feed_dict)
            train_writer.add_summary(summary, step)

    loss_sum = np.asarray(loss_sum)
    log_string('step: %d mean gen_loss_emd: %f\n' %
               (step, round(loss_sum.mean(), 4)))
    print('read data time: %s mean gen_loss_emd: %f' %
          (round(fetch_time, 4), round(loss_sum.mean(), 4)))
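The loop above assumes that ops['pointclouds_image_input'], ops['pointclouds_image_pred'], ops['pointclouds_image_gt'], and ops['image_merged'] were created at graph-construction time. A minimal sketch of how such image-summary ops can be wired up in TF1; the placeholder shapes and summary names here are assumptions, not taken from the original model code:

import tensorflow as tf

# Each three-view rendering arrives as a single-image batch with one channel.
img_input = tf.placeholder(tf.float32, shape=[None, None, None, 1])
img_pred = tf.placeholder(tf.float32, shape=[None, None, None, 1])
img_gt = tf.placeholder(tf.float32, shape=[None, None, None, 1])
image_merged = tf.summary.merge([
    tf.summary.image('input', img_input, max_outputs=1),
    tf.summary.image('pred', img_pred, max_outputs=1),
    tf.summary.image('gt', img_gt, max_outputs=1),
])
ops = {
    'pointclouds_image_input': img_input,
    'pointclouds_image_pred': img_pred,
    'pointclouds_image_gt': img_gt,
    'image_merged': image_merged,
}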
Example #2
    def train_one_epoch(self):
        loss_sum = []
        fetch_time = 0
        for batch_idx in range(self.fetchworker.num_batches):
            start = time.time()
            batch_data_input, batch_data_clean, batch_data_dist, \
                batch_data_edgeface, radius, point_order = self.fetchworker.fetch()
            batch_data_edge = np.reshape(
                batch_data_edgeface[:, 0:2 * NUM_EDGE, :],
                (BATCH_SIZE, NUM_EDGE, 6))
            batch_data_face = np.reshape(
                batch_data_edgeface[:, 2 * NUM_EDGE:2 * NUM_EDGE + 3 * NUM_FACE, :],
                (BATCH_SIZE, NUM_FACE, 9))
            # A = batch_data_face[:,:,3:6]-batch_data_face[:,:,0:3]
            # B = batch_data_face[:,:,6:9]-batch_data_face[:,:,0:3]
            # batch_data_normal = np.cross(A,B)+1e-12
            # batch_data_normal = batch_data_normal / np.sqrt(np.sum(batch_data_normal ** 2, axis=-1, keepdims=True))
            # batch_data_edgepoint =batch_data_edgeface[:, 2*NUM_EDGE+3*NUM_FACE:, :]
            end = time.time()
            fetch_time += end - start

            feed_dict = {self.pointclouds_input: batch_data_input,
                         self.pointclouds_idx: point_order,
                         self.pointclouds_edge: batch_data_edge,
                         self.pointclouds_surface: batch_data_face,
                         self.pointclouds_radius: radius}
            _, summary, step, pred_coord, pred_edgecoord, edgemask, edge_loss = self.sess.run(
                [
                    self.gen_train, self.merged, self.step, self.pred_coord,
                    self.pred_edgecoord, self.edgemask, self.edge_loss
                ],
                feed_dict=feed_dict)
            self.train_writer.add_summary(summary, step)
            loss_sum.append(edge_loss)
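            # assumption: forcing the first five mask entries to 1 keeps at
            # least five predicted edge points, so the slicing below (and the
            # visualization) never sees an empty array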
            edgemask[:, 0:5] = 1
            pred_edgecoord = pred_edgecoord[0][edgemask[0] == 1]
            if step % 30 == 0:
                pointclouds_image_input = pc_util.point_cloud_three_views(batch_data_input[0, :, 0:3])
                pointclouds_image_input = np.expand_dims(np.expand_dims(pointclouds_image_input, axis=-1), axis=0)
                pointclouds_image_pred = pc_util.point_cloud_three_views(pred_coord[0, :, 0:3])
                pointclouds_image_pred = np.expand_dims(np.expand_dims(pointclouds_image_pred, axis=-1), axis=0)
                pointclouds_image_gt = pc_util.point_cloud_three_views(pred_edgecoord[:, 0:3])
                pointclouds_image_gt = np.expand_dims(np.expand_dims(pointclouds_image_gt, axis=-1), axis=0)
                feed_dict = {self.pointclouds_image_input: pointclouds_image_input,
                             self.pointclouds_image_pred: pointclouds_image_pred,
                             self.pointclouds_image_gt: pointclouds_image_gt}
                summary = self.sess.run(self.image_merged, feed_dict)
                self.train_writer.add_summary(summary, step)
            if step % 100 == 0:
                loss_sum = np.asarray(loss_sum)
                log_string('step: %d edge_loss: %f\n' % (step, round(loss_sum.mean(), 4)))
                print('datatime: %s edge_loss: %f' % (round(fetch_time, 4), round(loss_sum.mean(), 4)))
                loss_sum = []
Example #3
def write_result():
    root_path = "/home/lqyu/server/proj49/PointSR_data/test_data/our_collected_data"
    model_names = ['1024_nonormal_generator2_2', '1024_nonormal_generator2_2_uniformloss',
                   '1024_nonormal_generator2_2_recursive']

    index_path = "index.html"
    index = open(index_path, "w")
    index.write("<html><body><table>")
    index.write("<tr><th width='5%'>name</th></tr>")

    index.write("<tr><th></th>")
    for model in model_names:
        index.write("<th>%s</th>" % model)
    index.write("</tr>")

    # get sample list
    items = os.listdir(root_path + "/" + model_names[0])
    items.sort()

    # mkdir model image path
    for model in model_names:
        if not os.path.exists(root_path + "/" + model + "_three_view_img/"):
            os.makedirs(root_path + "/" + model + "_three_view_img/")

    # write img to file
    for item in tqdm(items):
        index.write("<tr>")
        index.write("<td>%s</td>" % item)

        # write prediction
        for model in model_names:
            path = root_path + "/" + model + "/" + item
            if not os.path.exists(path):
                continue
            img_path = root_path + "/" + model + "_three_view_img/" + item
            img_path = img_path.replace(".xyz", ".png")
            if not os.path.exists(img_path):
                data = np.loadtxt(path)
                data = data[:, 0:3]
                img = pc_util.point_cloud_three_views(data, diameter=8)
                imsave(img_path, img)
            index.write("<td><img width='100%%', src='%s'></td>" % img_path)
        index.write("</tr>")
    index.write("</table></body></html>")
    index.close()
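write_result above and the two write_result2html_* functions below repeat the same step: render an .xyz file to a three-view .png only if the image does not exist yet. A sketch of a shared helper that could replace those blocks; the name render_three_views is hypothetical, not from the original code:

def render_three_views(xyz_path, img_path, diameter=8):
    # Hypothetical helper: render and cache a three-view image for an .xyz file.
    # Assumes os, np, pc_util, and imsave are imported as in the examples here.
    img_dir = os.path.dirname(img_path)
    if img_dir and not os.path.exists(img_dir):
        os.makedirs(img_dir)
    if not os.path.exists(img_path):
        data = np.loadtxt(xyz_path)[:, 0:3]
        imsave(img_path, pc_util.point_cloud_three_views(data, diameter=diameter))
    return img_path

Each "if not os.path.exists(img_path): ..." block then collapses to a single render_three_views(path, img_path) call.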
Example #4
    def test(self, show=False, use_normal=False):
        data_folder = '../../PointSR_data/CAD/mesh_MC16k'
        phase = data_folder.split('/')[-2] + data_folder.split('/')[-1]
        save_path = os.path.join(MODEL_DIR, 'result/' + phase)
        self.saver = tf.train.Saver()
        _, restore_model_path = model_utils.pre_load_checkpoint(MODEL_DIR)
        print(restore_model_path)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        with tf.Session(config=config) as sess:
            self.saver.restore(sess, restore_model_path)
            samples = glob(data_folder + "/*.xyz")
            samples.sort()
            total_time = 0

            #input, dist, edge, data_radius, name = data_provider.load_patch_data(NUM_POINT, True, 30)
            #edge = np.reshape(edge,[-1,NUM_EDGE,6])

            for i, item in tqdm(enumerate(samples)):
                input = np.loadtxt(item)
                gt = input  # keep the raw points for the optional visualization
                edge = np.loadtxt(
                    item.replace('mesh_MC16k',
                                 'mesh_edge').replace('.xyz', '_edge.xyz'))
                # drop degenerate edges whose two endpoints coincide
                idx = np.all(edge[:, 0:3] == edge[:, 3:6], axis=-1)
                edge = edge[~idx]
                # resample the edge list to exactly 1300 rows
                l = len(edge)
                idx = list(range(l)) * (1300 // l) + list(
                    np.random.permutation(l)[:1300 % l])
                edge = edge[idx]

                # # coord = input[:, 0:3]
                # # centroid = np.mean(coord, axis=0, keepdims=True)
                # # coord = coord - centroid
                # # furthest_distance = np.amax(np.sqrt(np.sum(abs(coord) ** 2, axis=-1)))
                # # coord = coord / furthest_distance
                # # input[:, 0:3] = coord
                input = np.expand_dims(input, axis=0)
                # input = data_provider.jitter_perturbation_point_cloud(input, sigma=0.01, clip=0.02)

                start_time = time.time()
                # note: these ops are rebuilt on every iteration, so the graph
                # grows with each sample; hoisting them out of the loop avoids that
                edge_pl = tf.placeholder(tf.float32, [1, edge.shape[0], 6])
                dist_gt_pl = tf.sqrt(
                    tf.reduce_min(
                        model_utils.distance_point2edge(self.pred, edge_pl),
                        axis=-1))

                pred, pred_dist, dist_gt = sess.run(
                    [self.pred, self.pred_dist, dist_gt_pl],
                    feed_dict={
                        self.pointclouds_input: input[:, :, 0:3],
                        self.pointclouds_radius: np.ones(BATCH_SIZE),
                        edge_pl: np.expand_dims(edge, axis=0)
                    })
                total_time += time.time() - start_time
                norm_pl = np.zeros_like(pred)
                ##--------------visualize predicted point cloud----------------------
                if show:
                    f, axis = plt.subplots(3)
                    axis[0].imshow(
                        pc_util.point_cloud_three_views(input[0, :, 0:3],
                                                        diameter=5))
                    axis[1].imshow(
                        pc_util.point_cloud_three_views(pred[0, :, :],
                                                        diameter=5))
                    axis[2].imshow(
                        pc_util.point_cloud_three_views(gt[:, 0:3],
                                                        diameter=5))
                    plt.show()

                path = os.path.join(save_path,
                                    item.split('/')[-1][:-4] + ".ply")
                # rgba =data_provider.convert_dist2rgba(pred_dist2,scale=10)
                # data_provider.save_ply(path, np.hstack((pred[0, ...],rgba,pred_dist2.reshape(NUM_ADDPOINT,1))))

                path = os.path.join(save_path,
                                    item.split('/')[-1][:-4] + "_gt.ply")
                rgba = data_provider.convert_dist2rgba(dist_gt[0], scale=5)
                data_provider.save_ply(
                    path,
                    np.hstack((pred[0, ...], rgba,
                               dist_gt.reshape(NUM_ADDPOINT, 1))))

                path = path.replace(phase, phase + "_input")
                path = path.replace('xyz', 'ply')
                rgba = data_provider.convert_dist2rgba(pred_dist[0], scale=5)
                data_provider.save_ply(
                    path,
                    np.hstack((input[0], rgba,
                               pred_dist.reshape(NUM_POINT, 1))))
            print(total_time / len(samples))
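model_utils.distance_point2edge is not shown on this page. From its use above (a reduce_min over the last axis followed by a sqrt) it plausibly returns squared distances from every predicted point to every edge segment, with each edge row packing two endpoints as xyz|xyz. A hedged, unbatched NumPy sketch under that assumption:

import numpy as np

def distance_point2edge_np(points, edges):
    # points: (N, 3); edges: (M, 6) rows packing [start_xyz, end_xyz].
    # Returns (N, M) squared point-to-segment distances.
    a, b = edges[:, 0:3], edges[:, 3:6]
    ab = b - a                                    # (M, 3) segment directions
    ap = points[:, None, :] - a[None, :, :]       # (N, M, 3)
    # projection parameter onto each segment, clamped to stay on the segment
    denom = np.maximum(np.einsum('md,md->m', ab, ab), 1e-12)
    t = np.clip(np.einsum('nmd,md->nm', ap, ab) / denom, 0.0, 1.0)
    closest = a[None, :, :] + t[..., None] * ab[None, :, :]
    return np.sum((points[:, None, :] - closest) ** 2, axis=-1)

Per batch element, the TF expression above then corresponds to np.sqrt(distance_point2edge_np(pred, edge).min(axis=-1)) in this sketch.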
Example #5
def write_result2html_benchmark():
    root_path = "/home/lqyu/server/proj49/PointSR_data/test_data/our_collected_data"
    phase = 'surface_benchmark'
    input_path = "../data/" + phase + "/1024_nonuniform"
    gt_path = "../data/" + phase + "/4096"
    model_names = ['1024_nonormal_generator2_2',
                   '1024_nonormal_generator2_2_uniformloss',
                   '1024_nonormal_generator2_2_recursive']


    index_path = os.path.join(root_path, phase + "_index.html")
    index = open(index_path, "w")
    index.write("<html><body><table><tr>")
    index.write("<th width='5%%'>name</th><th>Input</th>")
    index.write("<th>Refered GT</th></tr>")

    index.write("<tr><th></th>")
    for model in model_names:
        index.write("<th>%s</th>" % model)
    index.write("</tr>")

    # get sample list
    items = os.listdir(root_path + "/" + model_names[0] + "/result/" + phase)
    items.sort()

    # mkdir model image path
    for model in model_names:
        if not os.path.exists(root_path + "/" + model + "/result/" + phase + "_three_view_img/"):
            os.makedirs(root_path + "/" + model + "/result/" + phase + "_three_view_img/")

    # write img to file
    for item in tqdm(items):
        index.write("<tr>")
        index.write("<td>%s</td>" % item)

        # write input image
        object = item.split("_")[0]
        id = item.split(".")[0]
        path = input_path + "/%s.xyz" % (id)
        img_path = input_path + "_three_view_img/%s.png" % (id)
        if not os.path.exists(input_path + "_three_view_img/"):
            os.makedirs(input_path + "_three_view_img/")
        if not os.path.exists(img_path):
            data = np.loadtxt(path)
            data = data[:, 0:3]
            img = pc_util.point_cloud_three_views(data, diameter=8)
            imsave(img_path, img)
        index.write("<td><img width='100%%' src='%s'></td>" % img_path)
        # write gt image
        path = gt_path + "/%s.xyz" % (id)
        img_path = gt_path + "_three_view_img/%s.png" % (id)
        if not os.path.exists(gt_path + "_three_view_img/"):
            os.makedirs(gt_path + "_three_view_img/")
        if not os.path.exists(img_path):
            data = np.loadtxt(path)
            data = data[:, 0:3]
            img = pc_util.point_cloud_three_views(data, diameter=8)
            imsave(img_path, img)
        index.write("<td><img width='100%%' src='%s'></td>" % img_path)
        index.write("</tr>")

        index.write("<tr><th></th>")
        # write prediction
        for model in model_names:
            path = root_path + "/" + model + "/result/" + phase + "/" + item
            if not os.path.exists(path):
                continue
            img_path = root_path + "/" + model + "/result/" + phase + "_three_view_img/" + item
            img_path = img_path.replace(".xyz", ".png")
            if not os.path.exists(img_path):
                data = np.loadtxt(path)
                data = data[:, 0:3]
                img = pc_util.point_cloud_three_views(data, diameter=8)
                imsave(img_path, img)
            index.write("<td><img width='100%%' src='%s'></td>" % img_path)
        index.write("</tr>")
    index.write("</table></body></html>")
    index.close()
Example #6
def write_result2html_ModelNet():
    root_path = "../model"
    gt_path = "../data/ModelNet10_poisson_normal"
    #gt_path = "../data/Patches"
    model_names = ['1024_generator2_2', 'new_1024_generator2_2',
                   'new_1024_generator2_2_fixed_lr']
    phase = 'test'

    index_path = os.path.join(root_path, phase + "_index.html")
    index = open(index_path, "w")
    index.write("<html><body><table><tr>")
    index.write("<th width='5%%'>name</th><th>Input</th>")
    index.write("<th>Refered GT</th></tr>")

    index.write("<tr><th></th>")
    for model in model_names:
        index.write("<th>%s</th>" % model)
    index.write("</tr>")

    # get sample list
    items = os.listdir(root_path + "/" + model_names[0] + "/result/" + phase)
    items.sort()

    # mkdir model image path
    for model in model_names:
        if not os.path.exists(root_path + "/" + model + "/result/" + phase + "_three_view_img/"):
            os.makedirs(root_path + "/" + model + "/result/" + phase + "_three_view_img/")

    # write img to file
    for item in tqdm(items[::25]):
        index.write("<tr>")
        index.write("<td>%s</td>" % item)

        # write input image
        object = item.split("_")[0]
        id = item.split(".")[0]
        fixed = "%s/1024_nonuniform/%s" % (gt_path, 'train')
        path = fixed + "/%s.xyz" % (id)
        img_path = fixed + "_three_view_img/%s.png" % (id)
        if not os.path.exists(fixed + "_three_view_img/"):
            os.makedirs(fixed + "_three_view_img/")
        if not os.path.exists(img_path):
            data = np.loadtxt(path)
            data = data[:, 0:3]
            img = pc_util.point_cloud_three_views(data, diameter=8)
            imsave(img_path, img)
        index.write("<td><img width='100%%' src='%s'></td>" % img_path)
        # write gt image
        fixed = "%s/4096/%s" % (gt_path, 'train')
        path = fixed + "/%s.xyz" % (id)
        img_path = fixed + "_three_view_img/%s.png" % (id)
        if not os.path.exists(fixed + "_three_view_img/"):
            os.makedirs(fixed + "_three_view_img/")
        if not os.path.exists(img_path):
            data = np.loadtxt(path)
            data = data[:, 0:3]
            img = pc_util.point_cloud_three_views(data, diameter=8)
            imsave(img_path, img)
        index.write("<td><img width='100%%' src='%s'></td>" % img_path)
        index.write("</tr>")

        index.write("<tr><th></th>")
        # write prediction
        for model in model_names:
            path = root_path + "/" + model + "/result/" + phase + "/" + item
            if not os.path.exists(path):
                continue
            img_path = root_path + "/" + model + "/result/" + phase + "_three_view_img/" + item
            img_path = img_path.replace(".xyz", ".png")
            if not os.path.exists(img_path):
                data = np.loadtxt(path)
                data = data[:, 0:3]
                img = pc_util.point_cloud_three_views(data, diameter=8)
                imsave(img_path, img)
            index.write("<td><img width='100%%' src='%s'></td>" % img_path)
        index.write("</tr>")
    index.write("</table></body></html>")
    index.close()
Example #7
def prediction_whole_model(data_folder=None, show=False, use_normal=False):
    data_folder = '../data/test_data/our_collected_data/MC_5k'
    phase = data_folder.split('/')[-2] + data_folder.split('/')[-1]
    save_path = os.path.join(MODEL_DIR, 'result/' + phase)

    if not os.path.exists(save_path):
        os.makedirs(save_path)
    samples = glob(data_folder + "/*.xyz")
    samples.sort(reverse=True)
    input = np.loadtxt(samples[0])

    if use_normal:
        pointclouds_ipt = tf.placeholder(tf.float32,
                                         shape=(1, input.shape[0], 6))
    else:
        pointclouds_ipt = tf.placeholder(tf.float32,
                                         shape=(1, input.shape[0], 3))
    pred, _ = MODEL_GEN.get_gen_model(pointclouds_ipt,
                                      is_training=False,
                                      scope='generator',
                                      bradius=1.0,
                                      reuse=None,
                                      use_normal=use_normal,
                                      use_bn=False,
                                      use_ibn=False,
                                      bn_decay=0.95,
                                      up_ratio=UP_RATIO)
    saver = tf.train.Saver()
    _, restore_model_path = model_utils.pre_load_checkpoint(MODEL_DIR)
    print(restore_model_path)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    with tf.Session(config=config) as sess:
        saver.restore(sess, restore_model_path)
        samples = glob(data_folder + "/*.xyz")
        samples.sort()
        total_time = 0
        for i, item in enumerate(samples):
            input = np.loadtxt(item)
            gt = input

            # input = data_provider.jitter_perturbation_point_cloud(np.expand_dims(input,axis=0),sigma=0.003,clip=0.006)
            input = np.expand_dims(input, axis=0)

            if not use_normal:
                input = input[:, :, 0:3]
                gt = gt[:, 0:3]
            print(item, input.shape)

            start_time = time.time()
            pred_pl = sess.run(pred, feed_dict={pointclouds_ipt: input})
            total_time += time.time() - start_time
            norm_pl = np.zeros_like(pred_pl)

            ##--------------visualize predicted point cloud----------------------
            path = os.path.join(save_path, item.split('/')[-1])
            if show:
                f, axis = plt.subplots(3)
                axis[0].imshow(
                    pc_util.point_cloud_three_views(input[0, :, 0:3],
                                                    diameter=5))
                axis[1].imshow(
                    pc_util.point_cloud_three_views(pred_pl[0, :, :],
                                                    diameter=5))
                axis[2].imshow(
                    pc_util.point_cloud_three_views(gt[:, 0:3], diameter=5))
                plt.show()
            data_provider.save_pl(
                path, np.hstack((pred_pl[0, ...], norm_pl[0, ...])))
            path = path[:-4] + '_input.xyz'
            data_provider.save_pl(path, input[0])
        print(total_time / len(samples))
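The commented-out line above calls data_provider.jitter_perturbation_point_cloud, whose body is not shown on this page. A sketch of the usual clipped-Gaussian jitter, assuming the common PointNet-style provider semantics:

import numpy as np

def jitter_perturbation_point_cloud(batch_data, sigma=0.005, clip=0.02):
    # batch_data: (B, N, 3). Add per-point Gaussian noise clipped to [-clip, clip].
    assert clip > 0
    jitter = np.clip(sigma * np.random.randn(*batch_data.shape), -clip, clip)
    return batch_data + jitter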
Example #8
def eval_one_epoch(sess, ops, num_votes=1, topk=1):
    error_cnt = 0
    is_training = False
    total_correct = 0
    total_seen = 0
    loss_sum = 0
    total_seen_class = [0 for _ in range(NUM_CLASSES)]
    total_correct_class = [0 for _ in range(NUM_CLASSES)]
    fout = open(os.path.join(DUMP_DIR, 'pred_label.txt'), 'w')
    for fn in range(len(TEST_FILES)):
        log_string('----' + str(fn) + '----')
        current_data, current_label = provider.loadDataFile(TEST_FILES[fn])
        current_data = current_data[:, 0:NUM_POINT, :]
        current_label = np.squeeze(current_label)
        print(current_data.shape)

        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE
        print(file_size)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx + 1) * BATCH_SIZE
            cur_batch_size = end_idx - start_idx

            # Aggregating BEG
            batch_loss_sum = 0  # sum of losses for the batch
            batch_pred_sum = np.zeros(
                (cur_batch_size, NUM_CLASSES))  # score for classes
            batch_pred_classes = np.zeros(
                (cur_batch_size, NUM_CLASSES))  # 0/1 for classes
            for vote_idx in range(num_votes):
                rotated_data = provider.rotate_point_cloud_by_angle(
                    current_data[start_idx:end_idx, :, :],
                    vote_idx / float(num_votes) * np.pi * 2)
                feed_dict = {
                    ops['pointclouds_pl']: rotated_data,
                    ops['labels_pl']: current_label[start_idx:end_idx],
                    ops['is_training_pl']: is_training
                }
                loss_val, pred_val = sess.run([ops['loss'], ops['pred']],
                                              feed_dict=feed_dict)
                batch_pred_sum += pred_val
                batch_pred_val = np.argmax(pred_val, 1)
                for el_idx in range(cur_batch_size):
                    batch_pred_classes[el_idx, batch_pred_val[el_idx]] += 1
                batch_loss_sum += (loss_val * cur_batch_size /
                                   float(num_votes))
            # pred_val_topk = np.argsort(batch_pred_sum, axis=-1)[:,-1*np.array(range(topk))-1]
            # pred_val = np.argmax(batch_pred_classes, 1)
            pred_val = np.argmax(batch_pred_sum, 1)
            # Aggregating END

            correct = np.sum(pred_val == current_label[start_idx:end_idx])
            # correct = np.sum(pred_val_topk[:,0:topk] == label_val)
            total_correct += correct
            total_seen += cur_batch_size
            loss_sum += batch_loss_sum

            for i in range(start_idx, end_idx):
                l = current_label[i]
                total_seen_class[l] += 1
                total_correct_class[l] += (pred_val[i - start_idx] == l)
                fout.write('%d, %d\n' % (pred_val[i - start_idx], l))

                if pred_val[i - start_idx] != l and FLAGS.visu:  # ERROR CASE, DUMP!
                    img_filename = '%d_label_%s_pred_%s.jpg' % (
                        error_cnt, SHAPE_NAMES[l],
                        SHAPE_NAMES[pred_val[i - start_idx]])
                    img_filename = os.path.join(DUMP_DIR, img_filename)
                    output_img = pc_util.point_cloud_three_views(
                        np.squeeze(current_data[i, :, :]))
                    scipy.misc.imsave(img_filename, output_img)
                    error_cnt += 1

    log_string('eval mean loss: %f' % (loss_sum / float(total_seen)))
    log_string('eval accuracy: %f' % (total_correct / float(total_seen)))
    log_string('eval avg class acc: %f' % (np.mean(
        np.array(total_correct_class) /
        np.array(total_seen_class, dtype=float))))

    class_accuracies = np.array(total_correct_class) / np.array(
        total_seen_class, dtype=float)
    for i, name in enumerate(SHAPE_NAMES):
        log_string('%10s:\t%0.3f' % (name, class_accuracies[i]))
    fout.close()
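provider.rotate_point_cloud_by_angle supplies the rotated copies used for the voting above. Its body is not shown here; a sketch assuming the standard PointNet provider behavior (rotation about the up axis):

import numpy as np

def rotate_point_cloud_by_angle(batch_data, rotation_angle):
    # batch_data: (B, N, 3). Rotate every cloud about the Y (up) axis
    # by rotation_angle radians.
    c, s = np.cos(rotation_angle), np.sin(rotation_angle)
    rotation_matrix = np.array([[c, 0.0, s],
                                [0.0, 1.0, 0.0],
                                [-s, 0.0, c]])
    return batch_data @ rotation_matrix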
Example #9
    def train_one_epoch(self):
        loss_sum = []
        fetch_time = 0
        for batch_idx in range(self.fetchworker.num_batches):
            start = time.time()
            batch_data_input, batch_data_clean, batch_data_dist, \
                batch_data_edgeface, radius, point_order = self.fetchworker.fetch()
            batch_data_edge = np.reshape(
                batch_data_edgeface[:, 0:2 * NUM_EDGE, :],
                (BATCH_SIZE, NUM_EDGE, 6))
            batch_data_face = np.reshape(
                batch_data_edgeface[:, 2 * NUM_EDGE:2 * NUM_EDGE + 3 * NUM_FACE, :],
                (BATCH_SIZE, NUM_FACE, 9))
            batch_data_edgepoint = batch_data_edgeface[:, 2 * NUM_EDGE + 3 * NUM_FACE:, :]
            end = time.time()
            fetch_time += end - start

            feed_dict = {
                self.pointclouds_input: batch_data_input,
                self.pointclouds_poisson: batch_data_clean,
                # self.pointclouds_dist: batch_data_dist,
                self.pointclouds_idx: point_order,
                self.pointclouds_edge: batch_data_edge,
                self.pointclouds_plane: batch_data_face,
                self.pointclouds_radius: radius
            }

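            # note: this run fetches self.edge_loss into the local name
            # gen_loss_emd, so the epoch log below actually reports the edge loss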
            _, summary, step, pred_coord, pred_edgecoord, edgemask, gen_loss_emd = self.sess.run(
                [
                    self.pre_gen_train, self.merged, self.step,
                    self.pred_coord, self.pred_edgecoord, self.edgemask,
                    self.edge_loss
                ],
                feed_dict=feed_dict)
            self.train_writer.add_summary(summary, step)
            loss_sum.append(gen_loss_emd)
            edgemask[:, 0:5] = 1
            pred_edgecoord = pred_edgecoord[0][edgemask[0] == 1]
            if step % 30 == 0:
                pointclouds_image_input = pc_util.point_cloud_three_views(
                    batch_data_input[0, :, 0:3])
                pointclouds_image_input = np.expand_dims(
                    np.expand_dims(pointclouds_image_input, axis=-1), axis=0)
                pointclouds_image_pred = pc_util.point_cloud_three_views(
                    pred_coord[0, :, 0:3])
                pointclouds_image_pred = np.expand_dims(
                    np.expand_dims(pointclouds_image_pred, axis=-1), axis=0)
                pointclouds_image_gt = pc_util.point_cloud_three_views(
                    pred_edgecoord[:, 0:3])
                pointclouds_image_gt = np.expand_dims(
                    np.expand_dims(pointclouds_image_gt, axis=-1), axis=0)
                feed_dict = {
                    self.pointclouds_image_input: pointclouds_image_input,
                    self.pointclouds_image_pred: pointclouds_image_pred,
                    self.pointclouds_image_gt: pointclouds_image_gt
                }

                summary = self.sess.run(self.image_merged, feed_dict)
                self.train_writer.add_summary(summary, step)
        loss_sum = np.asarray(loss_sum)
        log_string('step: %d mean gen_loss_emd: %f\n' %
                   (step, round(loss_sum.mean(), 4)))
        print('read data time: %s mean gen_loss_emd: %f' %
              (round(fetch_time, 4), round(loss_sum.mean(), 4)))