def initTF():
    global tf_session, ops
    with tf.device("/gpu:" + str(GPU_INDEX)):
        pointclouds_pl, labels_pl, _ = model.placeholder_inputs(1, NUM_POINT)
        print(tf.shape(pointclouds_pl))
        is_training_pl = tf.placeholder(tf.bool, shape=())

        pred, _ = model.get_model(pointclouds_pl,
                                  is_training_pl,
                                  NUM_CLASSES,
                                  hyperparams=PARAMS)

        # Add ops to save and restore all the variables.
        saver = tf.train.Saver()

    # Create a session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    tf_session = tf.Session(config=config)

    # Restore variables from disk.
    saver.restore(tf_session, CHECKPOINT)
    print("Model restored.")

    ops = {
        "pointclouds_pl": pointclouds_pl,
        "is_training_pl": is_training_pl,
        "pred": pred
    }
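
# A minimal usage sketch for the globals that initTF() sets up (hedged: the
# helper name below is illustrative, not part of the original code; the input
# must match the (1, NUM_POINT, ...) shape of the placeholder above):
def run_inference_example(points):
    feed = {ops["pointclouds_pl"]: points, ops["is_training_pl"]: False}
    return tf_session.run(ops["pred"], feed_dict=feed)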
Example #2
def predict():

    is_training = False

    with tf.device('/gpu:' + str(gpu_to_use)):
        is_training_ph = tf.placeholder(tf.bool, shape=())

        pointclouds_ph, ptsseglabel_ph, ptsseglabel_onehot_ph, ptsgroup_label_ph, _, _, _ = \
            model.placeholder_inputs(BATCH_SIZE, POINT_NUM, NUM_GROUPS, NUM_CATEGORY)

        net_output = model.get_model(pointclouds_ph,
                                     is_training_ph,
                                     group_cate_num=NUM_CATEGORY)

        group_mat_label = tf.matmul(
            ptsgroup_label_ph,
            tf.transpose(ptsgroup_label_ph,
                         perm=[0, 2,
                               1]))  #BxNxN: (i,j) if i and j in the same group
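
        # Worked example (for intuition): with one-hot group labels G of shape
        # BxNxG, G @ G^T is BxNxN. For N=3 points in groups [0, 1, 0]:
        #   G = [[1,0],[0,1],[1,0]]  =>  G @ G.T = [[1,0,1],[0,1,0],[1,0,1]]
        # so entry (i, j) is 1 exactly when points i and j share a group.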

    # Add ops to save and restore all the variables.

    saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:

        flog = open(os.path.join(OUTPUT_DIR, 'log.txt'), 'w')

        # Restore variables from disk.
        ckptstate = tf.train.get_checkpoint_state(PRETRAINED_MODEL_PATH)
        if ckptstate is not None:
            LOAD_MODEL_FILE = os.path.join(
                PRETRAINED_MODEL_PATH,
                os.path.basename(ckptstate.model_checkpoint_path))
            saver.restore(sess, LOAD_MODEL_FILE)
            printout(flog, "Model loaded in file: %s" % LOAD_MODEL_FILE)
        else:
            printout(flog,
                     "Fail to load modelfile: %s" % PRETRAINED_MODEL_PATH)

        total_acc = 0.0
        total_seen = 0

        ious = np.zeros(NUM_CATEGORY)
        totalnums = np.zeros(NUM_CATEGORY)

        tpsins = [[] for itmp in range(NUM_CATEGORY)]
        fpsins = [[] for itmp in range(NUM_CATEGORY)]

        positive_ins_sgpn = np.zeros(NUM_CATEGORY)
        total_sgpn = np.zeros(NUM_CATEGORY)
        at = 0.25

        for shape_idx in range(len(TEST_DATASET)):
            cur_data, cur_seg, cur_group, cur_smpw = get_test_batch(
                TEST_DATASET, shape_idx)
            printout(flog, '%d / %d ...' % (shape_idx, len(TEST_DATASET)))

            seg_output = np.zeros_like(cur_seg)
            segrefine_output = np.zeros_like(cur_seg)
            group_output = np.zeros_like(cur_group)
            conf_output = np.zeros_like(cur_group).astype(np.float64)  # np.float was removed in NumPy 1.24

            pts_group_label, _ = model.convert_groupandcate_to_one_hot(
                cur_group)
            pts_label_one_hot = model.convert_seg_to_one_hot(cur_seg)
            num_data = cur_data.shape[0]

            gap = 5e-3
            volume_num = int(1. / gap) + 1
            volume = -1 * np.ones([volume_num, volume_num, volume_num]).astype(np.int32)
            volume_seg = -1 * np.ones([volume_num, volume_num, volume_num]).astype(np.int32)
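
            # With gap = 5e-3 this is a 201^3 voxel grid: BlockMerging
            # quantizes the (assumed normalized, in [0, 1]) scene coordinates
            # in pts[:, 6:9] into it, so that per-block instance ids can be
            # stitched into consistent scene-level ids.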

            intersections = np.zeros(NUM_CATEGORY)
            unions = np.zeros(NUM_CATEGORY)

            for j in range(num_data):
                print("Processing: Shape [%d] Block[%d]" % (shape_idx, j))

                pts = cur_data[j, ...]

                feed_dict = {
                    pointclouds_ph: np.expand_dims(pts, 0),
                    ptsseglabel_onehot_ph: np.expand_dims(pts_label_one_hot[j, ...], 0),
                    ptsseglabel_ph: np.expand_dims(cur_seg[j, ...], 0),
                    ptsgroup_label_ph: np.expand_dims(pts_group_label[j, ...], 0),
                    is_training_ph: is_training,
                }

                pts_corr_val0, pred_confidence_val0, ptsclassification_val0, pts_corr_label_val0 = \
                    sess.run([net_output['simmat'],
                              net_output['conf'],
                              net_output['semseg'],
                              group_mat_label],
                              feed_dict=feed_dict)

                seg = cur_seg[j, ...]
                ins = cur_group[j, ...]

                pts_corr_val = np.squeeze(pts_corr_val0[0])  # NxG
                pred_confidence_val = np.squeeze(pred_confidence_val0[0])
                ptsclassification_val = np.argmax(
                    np.squeeze(ptsclassification_val0[0]), axis=1)

                seg = np.squeeze(seg)

                #print(label_bin)
                groupids_block, refineseg, group_seg = GroupMerging(
                    pts_corr_val, pred_confidence_val, ptsclassification_val,
                    label_bin)  # yolo_to_groupt(pts_corr_val, pts_corr_label_val0[0], seg, t=5)

                groupids = BlockMerging(volume, volume_seg, pts[:, 6:],
                                        groupids_block.astype(np.int32),
                                        group_seg, gap)

                seg_output[j, :] = ptsclassification_val
                group_output[j, :] = groupids
                conf_output[j, :] = pred_confidence_val
                total_acc += float(np.sum(ptsclassification_val == seg)) / ptsclassification_val.shape[0]
                total_seen += 1

            ###### Evaluation
            ### Instance Segmentation
            ## Pred
            group_pred = group_output.reshape(-1)
            seg_pred = seg_output.reshape(-1)
            seg_gt = cur_seg.reshape(-1)
            conf_pred = conf_output.reshape(-1)
            pts = cur_data.reshape([-1, 9])

            # Filtering: overwrite per-block group ids with the scene-level ids merged into the voxel volume
            x = (pts[:, 6] / gap).astype(np.int32)
            y = (pts[:, 7] / gap).astype(np.int32)
            z = (pts[:, 8] / gap).astype(np.int32)
            for i in range(group_pred.shape[0]):
                if volume[x[i], y[i], z[i]] != -1:
                    group_pred[i] = volume[x[i], y[i], z[i]]

            un = np.unique(group_pred)
            pts_in_pred = [[] for itmp in range(NUM_CATEGORY)]
            conf_in_pred = [[] for itmp in range(NUM_CATEGORY)]
            group_pred_final = -1 * np.ones_like(group_pred)
            grouppred_cnt = 0

            for ig, g in enumerate(un):  #each object in prediction
                if g == -1:
                    continue
                tmp = (group_pred == g)
                sem_seg_g = int(stats.mode(seg_pred[tmp])[0])
                if np.sum(tmp) > 0.25 * min_num_pts_in_group[sem_seg_g]:
                    conf_tmp = conf_pred[tmp]

                    pts_in_pred[sem_seg_g] += [tmp]
                    conf_in_pred[sem_seg_g].append(np.average(conf_tmp))
                    group_pred_final[tmp] = grouppred_cnt
                    grouppred_cnt += 1

            if False:
                pc_util.write_obj_color(
                    pts[:, :3], seg_pred.astype(np.int32),
                    os.path.join(OUTPUT_DIR, '%d_segpred.obj' % (shape_idx)))
                pc_util.write_obj_color(
                    pts[:, :3], group_pred_final.astype(np.int32),
                    os.path.join(OUTPUT_DIR, '%d_grouppred.obj' % (shape_idx)))
            '''
            # write to file
            cur_train_filename = TEST_DATASET.get_filename(shape_idx)
            scene_name = cur_train_filename
            counter = 0
            f_scene = open(os.path.join('output', scene_name + '.txt'), 'w')
            for i_sem in range(NUM_CATEGORY):
                for ins_pred, ins_conf in zip(pts_in_pred[i_sem], conf_in_pred[i_sem]):
                    f_scene.write('{}_{:03d}.txt {} {}\n'.format(os.path.join('output', 'pred_insts', scene_name), counter, i_sem, ins_conf))
                    with open(os.path.join('output', 'pred_insts', '{}_{:03}.txt'.format(scene_name, counter)), 'w') as f:
                        for i_ins in ins_pred:
                            if i_ins:
                                f.write('1\n')
                            else:
                                f.write('0\n')
                    counter += 1
            f_scene.close()

            # write_to_mesh
            mesh_filename = os.path.join('mesh', scene_name +'.ply')
            pc_util.write_ply(pts, mesh_filename)
            '''

            # GT
            group_gt = cur_group.reshape(-1)
            un = np.unique(group_gt)
            pts_in_gt = [[] for itmp in range(NUM_CATEGORY)]
            for ig, g in enumerate(un):
                tmp = (group_gt == g)
                sem_seg_g = int(stats.mode(seg_pred[tmp])[0])
                pts_in_gt[sem_seg_g] += [tmp]
                total_sgpn[sem_seg_g] += 1

            for i_sem in range(NUM_CATEGORY):
                tp = [0.] * len(pts_in_pred[i_sem])
                fp = [0.] * len(pts_in_pred[i_sem])
                gtflag = np.zeros(len(pts_in_gt[i_sem]))
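                # VOC-style matching: each predicted instance is matched to
                # the GT instance with the highest IoU; it is a true positive
                # only if that IoU >= `at` (0.25) and the GT instance was not
                # already claimed, otherwise it is a false positive
                # (duplicate detections included).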

                for ip, ins_pred in enumerate(pts_in_pred[i_sem]):
                    ovmax = -1.

                    for ig, ins_gt in enumerate(pts_in_gt[i_sem]):
                        union = (ins_pred | ins_gt)
                        intersect = (ins_pred & ins_gt)
                        iou = float(np.sum(intersect)) / np.sum(union)

                        if iou > ovmax:
                            ovmax = iou
                            igmax = ig

                    if ovmax >= at:
                        if gtflag[igmax] == 0:
                            tp[ip] = 1  # true
                            gtflag[igmax] = 1
                        else:
                            fp[ip] = 1  # multiple det
                    else:
                        fp[ip] = 1  # false positive

                tpsins[i_sem] += tp
                fpsins[i_sem] += fp

            ### Semantic Segmentation
            un, indices = np.unique(seg_gt, return_index=True)
            for segid in un:
                intersect = np.sum((seg_pred == segid) & (seg_gt == segid))
                union = np.sum((seg_pred == segid) | (seg_gt == segid))
                intersections[segid] += intersect
                unions[segid] += union
            iou = intersections / unions
            for i_iou, iou_ in enumerate(iou):
                if not np.isnan(iou_):
                    ious[i_iou] += iou_
                    totalnums[i_iou] += 1

        ap = np.zeros(NUM_CATEGORY)
        for i_sem in range(NUM_CATEGORY):
            ap[i_sem], _, _ = eval_3d_perclass(tpsins[i_sem], fpsins[i_sem],
                                               total_sgpn[i_sem])

        print('Instance Segmentation AP:', ap)
        print('Instance Segmentation mAP:', np.mean(ap))
        print('Semantic Segmentation IoU:', ious / totalnums)
        print('Semantic Segmentation Acc: %f' % (total_acc / total_seen))
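
# A hedged sketch of the AP computation that the tp/fp lists above feed into.
# eval_3d_perclass is not shown in this listing; the step-wise precision/recall
# integration below is one common definition and may differ from the actual
# implementation:
def average_precision_sketch(tp, fp, n_gt):
    tp = np.cumsum(np.asarray(tp, dtype=np.float64))
    fp = np.cumsum(np.asarray(fp, dtype=np.float64))
    recall = tp / max(float(n_gt), 1.0)
    precision = tp / np.maximum(tp + fp, 1e-12)
    ap, prev_recall = 0.0, 0.0
    for r, p in zip(recall, precision):
        ap += (r - prev_recall) * p  # area under the step-wise P-R curve
        prev_recall = r
    return ap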
Example #3
File: valid.py  Project: zhenglongyu/SGPN
def predict():
    is_training = False

    with tf.device('/gpu:' + str(gpu_to_use)):
        is_training_ph = tf.placeholder(tf.bool, shape=())

        pointclouds_ph, ptsseglabel_ph, ptsgroup_label_ph, _, _, _ = \
            model.placeholder_inputs(BATCH_SIZE, POINT_NUM, NUM_GROUPS, NUM_CATEGORY)

        group_mat_label = tf.matmul(
            ptsgroup_label_ph, tf.transpose(ptsgroup_label_ph, perm=[0, 2, 1]))
        net_output = model.get_model(pointclouds_ph,
                                     is_training_ph,
                                     group_cate_num=NUM_CATEGORY)

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:

        # Restore variables from disk.

        ckptstate = tf.train.get_checkpoint_state(PRETRAINED_MODEL_PATH)
        if ckptstate is not None:
            LOAD_MODEL_FILE = os.path.join(
                PRETRAINED_MODEL_PATH,
                os.path.basename(ckptstate.model_checkpoint_path))
            saver.restore(sess, LOAD_MODEL_FILE)
            print("Model loaded in file: %s" % LOAD_MODEL_FILE)
        else:
            print("Fail to load modelfile: %s" % PRETRAINED_MODEL_PATH)

        ths = np.zeros(NUM_CATEGORY)
        ths_ = np.zeros(NUM_CATEGORY)
        cnt = np.zeros(NUM_CATEGORY)
        min_groupsize = np.zeros(NUM_CATEGORY)
        min_groupsize_cnt = np.zeros(NUM_CATEGORY)

        for shape_idx in range(len_pts_files):

            cur_train_filename = test_file_list[shape_idx]

            if not os.path.exists(cur_train_filename):
                continue
            cur_data, cur_group, _, cur_seg = provider.loadDataFile_with_groupseglabel_stanfordindoor(
                cur_train_filename)

            if OUTPUT_VERBOSE:
                pts = np.reshape(cur_data, [-1, 9])
                output_point_cloud_rgb(
                    pts[:, 6:], pts[:, 3:6],
                    os.path.join(OUTPUT_DIR, '%d_pts.obj' % (shape_idx)))

            pts_label_one_hot, pts_label_mask = model.convert_seg_to_one_hot(
                cur_seg)
            pts_group_label, _ = model.convert_groupandcate_to_one_hot(
                cur_group)
            num_data = cur_data.shape[0]

            cur_seg_flatten = np.reshape(cur_seg, [-1])
            un, indices = np.unique(cur_group, return_index=True)
            for iu, u in enumerate(un):
                groupsize = np.sum(cur_group == u)
                groupcate = cur_seg_flatten[indices[iu]]
                min_groupsize[groupcate] += groupsize
                # print groupsize, min_groupsize[groupcate]/min_groupsize_cnt[groupcate]
                min_groupsize_cnt[groupcate] += 1

            for j in range(num_data):

                print("Processing: Shape [%d] Block[%d]" % (shape_idx, j))

                pts = cur_data[j, ...]

                feed_dict = {
                    pointclouds_ph: np.expand_dims(pts, 0),
                    ptsseglabel_ph: np.expand_dims(pts_label_one_hot[j, ...],
                                                   0),
                    ptsgroup_label_ph: np.expand_dims(pts_group_label[j, ...],
                                                      0),
                    is_training_ph: is_training,
                }

                pts_corr_val0, pred_confidence_val0, ptsclassification_val0, pts_corr_label_val0 = \
                                        sess.run([net_output['simmat'],
                                                  net_output['conf'],
                                                  net_output['semseg'],
                                                  group_mat_label],
                                                  feed_dict=feed_dict)
                seg = cur_seg[j, ...]
                ins = cur_group[j, ...]

                pts_corr_val = np.squeeze(pts_corr_val0[0])
                pred_confidence_val = np.squeeze(pred_confidence_val0[0])
                ptsclassification_val = np.argmax(
                    np.squeeze(ptsclassification_val0[0]), axis=1)

                pts_corr_label_val = np.squeeze(1 - pts_corr_label_val0)
                seg = np.squeeze(seg)
                ins = np.squeeze(ins)

                ind = (seg == 8)
                pts_corr_val0 = (pts_corr_val > 1.).astype(np.float64)
                print(np.mean(
                    np.transpose(
                        np.abs(pts_corr_label_val[ind] - pts_corr_val0[ind]),
                        axes=[1, 0])[ind]))

                ths, ths_, cnt = Get_Ths(pts_corr_val, seg, ins, ths, ths_, cnt)
                print(ths / cnt)

                if OUTPUT_VERBOSE:
                    un, indices = np.unique(ins, return_index=True)
                    for ii, idx in enumerate(indices):
                        corr = pts_corr_val[idx].copy()
                        output_scale_point_cloud(
                            pts[:, 6:], np.float32(corr),
                            os.path.join(
                                OUTPUT_DIR, '%d_%d_%d_%d_scale.obj' %
                                (shape_idx, j, un[ii], seg[idx])))
                        corr = pts_corr_label_val[idx]
                        output_scale_point_cloud(
                            pts[:, 6:], np.float32(corr),
                            os.path.join(
                                OUTPUT_DIR, '%d_%d_%d_%d_scalegt.obj' %
                                (shape_idx, j, un[ii], seg[idx])))
                    output_scale_point_cloud(
                        pts[:, 6:], np.float32(pred_confidence_val),
                        os.path.join(OUTPUT_DIR,
                                     '%d_%d_conf.obj' % (shape_idx, j)))
                    output_color_point_cloud(
                        pts[:, 6:], ptsclassification_val.astype(np.int32),
                        os.path.join(OUTPUT_DIR, '%d_seg.obj' % (shape_idx)))

        # Average the accumulated per-category thresholds; 0.2 is the fallback
        # for categories never seen during validation (cnt == 0).
        ths = [ths[i] / cnt[i] if cnt[i] != 0 else 0.2 for i in range(len(cnt))]
        np.savetxt(os.path.join(RESTORE_DIR, 'pergroup_thres.txt'), ths)

        min_groupsize = [
            int(float(min_groupsize[i]) /
                min_groupsize_cnt[i]) if min_groupsize_cnt[i] != 0 else 0
            for i in range(len(min_groupsize))
        ]
        np.savetxt(os.path.join(RESTORE_DIR, 'mingroupsize.txt'),
                   min_groupsize)
Example #4
def predict():
    is_training = False

    with tf.device('/gpu:' + str(gpu_to_use)):
        is_training_ph = tf.placeholder(tf.bool, shape=())

        pointclouds_ph, ptsseglabel_ph, ptsgroup_label_ph, _, _, _ = \
            model.placeholder_inputs(BATCH_SIZE, POINT_NUM, NUM_GROUPS, NUM_CATEGORY)

        net_output = model.get_model(pointclouds_ph, is_training_ph, group_cate_num=NUM_CATEGORY)
        group_mat_label = tf.matmul(ptsgroup_label_ph, tf.transpose(ptsgroup_label_ph, perm=[0, 2, 1])) #BxNxN: (i,j) if i and j in the same group

    # Add ops to save and restore all the variables.

    saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:

        ckptstate = tf.train.get_checkpoint_state(PRETRAINED_MODEL_PATH)
        if ckptstate is not None:
            LOAD_MODEL_FILE = os.path.join(PRETRAINED_MODEL_PATH,os.path.basename(ckptstate.model_checkpoint_path))
            saver.restore(sess, LOAD_MODEL_FILE)
            print("Model loaded in file: %s" % LOAD_MODEL_FILE)
        else:
            print("Fail to load modelfile: %s" % PRETRAINED_MODEL_PATH)


        for shape_idx in range(len_pts_files):

            cur_train_filename = test_file_list[shape_idx]

            if not os.path.exists(cur_train_filename):
                continue
            cur_data, cur_group, _, cur_seg = provider.loadDataFile_with_groupseglabel_stanfordindoor(cur_train_filename)

            seg_output = np.zeros_like(cur_seg)
            segrefine_output = np.zeros_like(cur_seg)
            group_output = np.zeros_like(cur_group)
            conf_output = np.zeros_like(cur_group).astype(np.float64)

            pts_label_one_hot, pts_label_mask = model.convert_seg_to_one_hot(cur_seg)
            pts_group_label, _ = model.convert_groupandcate_to_one_hot(cur_group)
            num_data = cur_data.shape[0]

            gap = 5e-3
            volume_num = int(1. / gap) + 1
            volume = -1 * np.ones([volume_num, volume_num, volume_num]).astype(np.int32)
            volume_seg = -1 * np.ones([volume_num, volume_num, volume_num, NUM_CATEGORY]).astype(np.int32)

            intersections = np.zeros(NUM_CATEGORY)
            unions = np.zeros(NUM_CATEGORY)
            print('[%d / %d] Block Number: %d' % (shape_idx, len_pts_files, num_data))
            print('Loading train file %s' % (cur_train_filename))

            flag = True
            for j in range(num_data):

                pts = cur_data[j,...]

                feed_dict = {
                    pointclouds_ph: np.expand_dims(pts,0),
                    ptsseglabel_ph: np.expand_dims(pts_label_one_hot[j,...],0),
                    ptsgroup_label_ph: np.expand_dims(pts_group_label[j,...],0),
                    is_training_ph: is_training,
                }

                pts_corr_val0, pred_confidence_val0, ptsclassification_val0, pts_corr_label_val0 = \
                    sess.run([net_output['simmat'],
                              net_output['conf'],
                              net_output['semseg'],
                              group_mat_label],
                              feed_dict=feed_dict)

                seg = cur_seg[j,...]
                ins = cur_group[j,...]

                pts_corr_val = np.squeeze(pts_corr_val0[0]) #NxG
                pred_confidence_val = np.squeeze(pred_confidence_val0[0])
                ptsclassification_val = np.argmax(np.squeeze(ptsclassification_val0[0]),axis=1)

                seg = np.squeeze(seg)
                # print label_bin

                groupids_block, refineseg, group_seg = GroupMerging_old(
                    pts_corr_val, pred_confidence_val, ptsclassification_val,
                    label_bin)  # yolo_to_groupt(pts_corr_val, pts_corr_label_val0[0], seg, t=5)
                groupids = BlockMerging(volume, volume_seg, pts[:, 6:],
                                        groupids_block.astype(np.int32),
                                        group_seg, gap)

                seg_output[j, :] = ptsclassification_val
                segrefine_output[j, :] = refineseg
                group_output[j, :] = groupids
                conf_output[j, :] = pred_confidence_val

            ###### Generate Results for Evaluation

            basefilename = os.path.basename(cur_train_filename).split('.')[-2]
            scene_fn = os.path.join(OUTPUT_DIR, '%s.txt' % basefilename)
            f_scene = open(scene_fn, 'w')
            scene_gt_fn = os.path.join(GT_DIR, '%s.txt' % basefilename)
            group_pred = group_output.reshape(-1)
            seg_pred = seg_output.reshape(-1)
            conf = conf_output.reshape(-1)
            pts = cur_data.reshape([-1, 9])

            # filtering
            x = (pts[:, 6] / gap).astype(np.int32)
            y = (pts[:, 7] / gap).astype(np.int32)
            z = (pts[:, 8] / gap).astype(np.int32)
            for i in range(group_pred.shape[0]):
                if volume[x[i], y[i], z[i]] != -1:
                    group_pred[i] = volume[x[i], y[i], z[i]]

            un = np.unique(group_pred)
            pts_in_pred = [[] for itmp in range(NUM_CATEGORY)]
            group_pred_final = -1 * np.ones_like(group_pred)
            grouppred_cnt = 0

            for ig, g in enumerate(un): #each object in prediction
                if g == -1:
                    continue
                obj_fn = "predicted_masks/%s_%d.txt" % (basefilename, ig)
                tmp = (group_pred == g)
                sem_seg_g = int(stats.mode(seg_pred[tmp])[0])
                if np.sum(tmp) > 0.25 * min_num_pts_in_group[sem_seg_g]:
                    pts_in_pred[sem_seg_g] += [tmp]
                    group_pred_final[tmp] = grouppred_cnt
                    conf_obj = np.mean(conf[tmp])
                    grouppred_cnt += 1
                    f_scene.write("%s %d %f\n" % (obj_fn, sem_seg_g, conf_obj))
                    np.savetxt(os.path.join(OUTPUT_DIR, obj_fn), tmp.astype(np.int32), fmt='%d')

            seg_gt = cur_seg.reshape(-1)
            group_gt = cur_group.reshape(-1)
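            # Encode (semantic label, instance id) into one integer, e.g.
            # seg 3, group 12 -> 3012; this assumes fewer than 1000 instances
            # per scene, so both labels can be recovered via divmod(gid, 1000).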
            groupid_gt = seg_gt * 1000 + group_gt
            np.savetxt(scene_gt_fn, groupid_gt.astype(np.int64), fmt='%d')

            f_scene.close()

            if output_verbose:
                output_color_point_cloud(pts[:, 6:], seg_pred.astype(np.int32),
                                         os.path.join(OUTPUT_DIR, '%s_segpred.obj' % basefilename))
                output_color_point_cloud(pts[:, 6:], group_pred_final.astype(np.int32),
                                         os.path.join(OUTPUT_DIR, '%s_grouppred.obj' % basefilename))
Example #5
def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:' + str(FLAGS.gpu)):
            batch = tf.Variable(0, trainable=False, name='batch')
            learning_rate = tf.train.exponential_decay(
                BASE_LEARNING_RATE,  # base learning rate
                batch * BATCH_SIZE,  # global_var indicating the number of steps
                DECAY_STEP,  # step size
                DECAY_RATE,  # decay rate
                staircase=True  # Stair-case or continuous decreasing
            )
            bn_decay = get_bn_decay(batch)
            learning_rate = tf.maximum(learning_rate, LEARNING_RATE_CLIP)
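
            # Effective schedule (staircase decay, then clipped from below):
            #   lr(step) = max(BASE_LEARNING_RATE *
            #                  DECAY_RATE ** floor(step * BATCH_SIZE / DECAY_STEP),
            #                  LEARNING_RATE_CLIP)
            # where `step` is the `batch` variable, incremented per optimizer step.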

            lr_op = tf.summary.scalar('learning_rate', learning_rate)

            pointclouds_ph, ptsseglabel_ph, ptsseglabel_onehot_ph, ptsgroup_label_ph, pts_seglabel_mask_ph, pts_group_mask_ph, alpha_ph = \
                model.placeholder_inputs(BATCH_SIZE, POINT_NUM, NUM_GROUPS, NUM_CATEGORY)
            is_training_ph = tf.placeholder(tf.bool, shape=())

            labels = {'ptsgroup': ptsgroup_label_ph,
                      'semseg': ptsseglabel_ph,
                      'semseg_onehot': ptsseglabel_onehot_ph,
                      'semseg_mask': pts_seglabel_mask_ph,
                      'group_mask': pts_group_mask_ph}

            net_output = model.get_model(pointclouds_ph, is_training_ph, group_cate_num=NUM_CATEGORY, m=MARGINS[0], bn_decay=bn_decay)
            ptsseg_loss, simmat_loss, loss, grouperr, same, same_cnt, diff, diff_cnt, pos, pos_cnt = model.get_loss(net_output, labels, alpha_ph, MARGINS)

            total_training_loss_ph = tf.placeholder(tf.float32, shape=())
            group_err_loss_ph = tf.placeholder(tf.float32, shape=())
            total_train_loss_sum_op = tf.summary.scalar('total_training_loss', total_training_loss_ph)
            group_err_op = tf.summary.scalar('group_err_loss', group_err_loss_ph)

        train_variables = tf.trainable_variables()

        trainer = tf.train.AdamOptimizer(learning_rate)
        train_op = trainer.minimize(loss, var_list=train_variables, global_step=batch)
        train_op_pretrain = trainer.minimize(ptsseg_loss, var_list=train_variables, global_step=batch)
        train_op_5epoch = trainer.minimize(simmat_loss, var_list=train_variables, global_step=batch)

        loader = tf.train.Saver([v for v in tf.all_variables()
                                 if ('conf_logits' not in v.name) and
                                    ('Fsim' not in v.name) and
                                    ('Fsconf' not in v.name) and
                                    ('batch' not in v.name)])
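        # Hedged reading: the loader deliberately skips the confidence head
        # ('conf_logits'), the similarity branches ('Fsim', 'Fsconf') and the
        # step counter ('batch'), so those are trained from scratch while the
        # remaining weights come from the pretrained checkpoint.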
        saver = tf.train.Saver([v for v in tf.all_variables()], max_to_keep=200)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.Session(config=config)

        init = tf.global_variables_initializer()
        sess.run(init)

        train_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/train', sess.graph)

        fcmd = open(os.path.join(LOG_STORAGE_PATH, 'cmd.txt'), 'w')
        fcmd.write(str(FLAGS))
        fcmd.close()

        flog = open(os.path.join(LOG_STORAGE_PATH, 'log.txt'), 'w')

        ckptstate = tf.train.get_checkpoint_state(PRETRAINED_MODEL_PATH)
        if ckptstate is not None:
            LOAD_MODEL_FILE = os.path.join(PRETRAINED_MODEL_PATH, os.path.basename(ckptstate.model_checkpoint_path))
            loader.restore(sess, LOAD_MODEL_FILE)
            printout(flog, "Model loaded in file: %s" % LOAD_MODEL_FILE)
        else:
            printout(flog, "Fail to load modelfile: %s" % PRETRAINED_MODEL_PATH)


        ## load test data into memory
        test_data = []
        test_group = []
        test_seg = []
        test_smpw = []
        for i in range(len(TEST_DATASET)):
            print(i)
            cur_data, cur_seg, cur_group, cur_smpw = get_test_batch(TEST_DATASET, i)
            test_data += [cur_data]
            test_group += [cur_group]
            test_seg += [cur_seg]
            test_smpw += [cur_smpw]

        test_data = np.concatenate(test_data,axis=0)
        test_group = np.concatenate(test_group,axis=0)
        test_seg = np.concatenate(test_seg,axis=0)
        test_smpw = np.concatenate(test_smpw,axis=0)
        num_data_test = test_data.shape[0]
        num_batch_test = num_data_test // BATCH_SIZE

        def train_one_epoch(epoch_num):

            ### NOTE: is_training = False means we do not update BN parameters during training, due to the small batch size. This requires pre-training PointNet with a large batch size (say 32).
            if PRETRAIN:
                is_training = True
            else:
                is_training = False

            total_loss = 0.0
            total_grouperr = 0.0
            total_same = 0.0
            total_diff = 0.0
            total_pos = 0.0
            same_cnt0 = 0

            train_idxs = np.arange(0, len(TRAIN_DATASET))
            np.random.shuffle(train_idxs)
            num_batches = len(TRAIN_DATASET)//BATCH_SIZE
            for batch_idx in range(num_batches):
                print('{}/{}'.format(batch_idx, num_batches))
                start_idx = batch_idx * BATCH_SIZE
                end_idx = (batch_idx+1) * BATCH_SIZE
                batch_data, batch_label, batch_group, batch_smpw = get_batch(TRAIN_DATASET, train_idxs, start_idx, end_idx)
                aug_data = provider.rotate_point_cloud_z(batch_data)
                pts_label_one_hot = model.convert_seg_to_one_hot(batch_label)

                if PRETRAIN:
                    feed_dict = {
                        pointclouds_ph: aug_data, 
                        ptsseglabel_ph: batch_label,
                        ptsseglabel_onehot_ph: pts_label_one_hot,
                        pts_seglabel_mask_ph: batch_smpw,
                        is_training_ph: is_training,
                        alpha_ph: min(10., (float(epoch_num) / 5.) * 2. + 2.),
                    }
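                    # alpha_ph schedule: min(10, (epoch / 5) * 2 + 2) gives
                    # 2.0 at epoch 0, 4.0 at epoch 5, and caps at 10.0 from
                    # epoch 20 onward.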
                    _, loss_val  = sess.run([train_op_pretrain, ptsseg_loss], feed_dict=feed_dict)
                    total_loss += loss_val
                    if batch_idx % 10 == 9:
                        printout(flog, 'Batch: %d, loss: %f' % (batch_idx, total_loss/10))
                        total_loss = 0.0
                else:
                    pts_group_label, pts_group_mask = model.convert_groupandcate_to_one_hot(batch_group)
                    feed_dict = {
                        pointclouds_ph: batch_data,
                        ptsseglabel_ph: batch_label,
                        ptsseglabel_onehot_ph: pts_label_one_hot,
                        pts_seglabel_mask_ph: batch_smpw,
                        ptsgroup_label_ph: pts_group_label,
                        pts_group_mask_ph: pts_group_mask,
                        is_training_ph: is_training,
                        alpha_ph: min(10., (float(epoch_num) / 5.) * 2. + 2.),
                    }

                    if epoch_num < 20:
                        _, loss_val, simmat_val, grouperr_val, same_val, same_cnt_val, diff_val, diff_cnt_val, pos_val, pos_cnt_val = sess.run([train_op_5epoch, simmat_loss, net_output['simmat'], grouperr, same, same_cnt, diff, diff_cnt, pos, pos_cnt], feed_dict=feed_dict)
                    else:
                        _, loss_val, simmat_val, grouperr_val, same_val, same_cnt_val, diff_val, diff_cnt_val, pos_val, pos_cnt_val = sess.run([train_op, loss, net_output['simmat'], grouperr, same, same_cnt, diff, diff_cnt, pos, pos_cnt], feed_dict=feed_dict)

                    total_loss += loss_val
                    total_grouperr += grouperr_val
                    total_diff += (diff_val / diff_cnt_val)
                    if same_cnt_val > 0:
                        total_same += same_val / same_cnt_val
                        same_cnt0 += 1
                    total_pos += pos_val / pos_cnt_val

                    if batch_idx % 10 == 9:
                        printout(flog, 'Batch: %d, loss: %f, grouperr: %f, same: %f, diff: %f, pos: %f' % (batch_idx, total_loss/10, total_grouperr/10, total_same/same_cnt0, total_diff/10, total_pos/10))

                        lr_sum, batch_sum, train_loss_sum, group_err_sum = sess.run( \
                            [lr_op, batch, total_train_loss_sum_op, group_err_op], \
                            feed_dict={total_training_loss_ph: total_loss / 10.,
                                       group_err_loss_ph: total_grouperr / 10., })

                        train_writer.add_summary(train_loss_sum, batch_sum)
                        train_writer.add_summary(lr_sum, batch_sum)
                        train_writer.add_summary(group_err_sum, batch_sum)

                        total_grouperr = 0.0
                        total_loss = 0.0
                        total_diff = 0.0
                        total_same = 0.0
                        total_pos = 0.0
                        same_cnt0 = 0

            cp_filename = saver.save(sess, os.path.join(MODEL_STORAGE_PATH, 'epoch_' + str(epoch_num + 1) + '.ckpt'))
            printout(flog, 'Successfully store the checkpoint model into ' + cp_filename)

        def val_one_epoch(epoch_num):
            is_training = False

            def evaluate_confusion(confusion_matrix, epoch):
                conf = confusion_matrix.value()
                total_correct = 0
                valids = np.zeros(NUM_CATEGORY, dtype=np.float32)
                for c in range(NUM_CATEGORY):
                    num = conf[c,:].sum()
                    valids[c] = -1 if num == 0 else float(conf[c][c]) / float(num)
                    total_correct += conf[c][c]
                instance_acc = -1 if conf.sum() == 0 else float(total_correct) / float(conf.sum())
                avg_acc = -1 if np.all(np.equal(valids, -1)) else np.mean(valids[np.not_equal(valids, -1)])
                print('Epoch: {}\tAcc(inst): {:.6f}\tAcc(avg): {:.6f}'.format(epoch, instance_acc, avg_acc))
                for class_ind, class_acc in enumerate(valids[np.not_equal(valids, -1)]):
                    print('{}: {}'.format(class_ind, class_acc))
                with open(os.path.join(LOG_STORAGE_PATH, 'ACC_{}.txt'.format(epoch)), 'w') as f:
                    f.write('Epoch: {}\tAcc(inst): {:.6f}\tAcc(avg): {:.6f}'.format(epoch, instance_acc, avg_acc))
                    for class_ind, class_acc in enumerate(valids[np.not_equal(valids, -1)]):
                        f.write('{}: {}\n'.format(class_ind, class_acc))

            confusion_val = tnt.meter.ConfusionMeter(NUM_CATEGORY)
            for j in range(0, num_batch_test):
                print('{}/{}'.format(j, num_batch_test))
                start_idx = j * BATCH_SIZE
                end_idx = (j + 1) * BATCH_SIZE
                pts_label_one_hot = model.convert_seg_to_one_hot(test_seg[start_idx:end_idx])
                feed_dict = {
                    pointclouds_ph: test_data[start_idx:end_idx,...],
                    ptsseglabel_ph: test_seg[start_idx:end_idx],
                    ptsseglabel_onehot_ph: pts_label_one_hot,
                    pts_seglabel_mask_ph: test_smpw[start_idx:end_idx, ...],
                    is_training_ph: is_training,
                    alpha_ph: min(10., (float(epoch_num) / 5.) * 2. + 2.),
                }

                ptsclassification_val0 = sess.run([net_output['semseg']], feed_dict=feed_dict)
                ptsclassification_val = torch.from_numpy(ptsclassification_val0[0]).view(-1, NUM_CATEGORY)
                ptsclassification_gt = torch.from_numpy(pts_label_one_hot).view(-1, NUM_CATEGORY)
                #import ipdb
                #ipdb.set_trace()
                #pc_util.write_obj_color(np.reshape(test_data[:BATCH_SIZE,:,:3], [-1,3])[:,:3], np.argmax(ptsclassification_val.numpy(), 1), 'pred3.obj')
                #pc_util.write_obj_color(np.reshape(test_data[:BATCH_SIZE,:,:3], [-1,3])[:,:3], np.argmax(ptsclassification_gt.numpy(), 1), 'gt.obj')
                confusion_val.add(target=ptsclassification_gt, predicted=ptsclassification_val)
            evaluate_confusion(confusion_val, epoch_num)


        if not os.path.exists(MODEL_STORAGE_PATH):
            os.mkdir(MODEL_STORAGE_PATH)

        for epoch in range(TRAINING_EPOCHES):
            printout(flog, '\n>>> Training for the epoch %d/%d ...' % (epoch, TRAINING_EPOCHES))
            train_one_epoch(epoch)
            flog.flush()
            if PRETRAIN:
                val_one_epoch(epoch)

        flog.close()
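
# get_bn_decay is referenced above but not shown in this listing. A hedged
# sketch in the usual PointNet style (the constants are illustrative
# assumptions, not values taken from this code):
def get_bn_decay_sketch(batch):
    BN_INIT_DECAY, BN_DECAY_RATE = 0.5, 0.5
    BN_DECAY_STEP, BN_DECAY_CLIP = 200000, 0.99
    bn_momentum = tf.train.exponential_decay(
        BN_INIT_DECAY, batch * BATCH_SIZE, BN_DECAY_STEP, BN_DECAY_RATE,
        staircase=True)
    # the batch-norm decay grows toward (but never exceeds) BN_DECAY_CLIP
    return tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)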
Example #6
File: train.py  Project: zhenglongyu/SGPN
def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:' + str(FLAGS.gpu)):
            batch = tf.Variable(0, trainable=False, name='batch')
            learning_rate = tf.train.exponential_decay(
                BASE_LEARNING_RATE,  # base learning rate
                batch * BATCH_SIZE,  # global_var indicating the number of steps
                DECAY_STEP,  # step size
                DECAY_RATE,  # decay rate
                staircase=True  # Stair-case or continuous decreasing
            )
            learning_rate = tf.maximum(learning_rate, LEARNING_RATE_CLIP)

            lr_op = tf.summary.scalar('learning_rate', learning_rate)

            pointclouds_ph, ptsseglabel_ph, ptsgroup_label_ph, pts_seglabel_mask_ph, pts_group_mask_ph, alpha_ph = \
                model.placeholder_inputs(BATCH_SIZE, POINT_NUM, NUM_GROUPS, NUM_CATEGORY)
            is_training_ph = tf.placeholder(tf.bool, shape=())

            labels = {'ptsgroup': ptsgroup_label_ph,
                      'semseg': ptsseglabel_ph,
                      'semseg_mask': pts_seglabel_mask_ph,
                      'group_mask': pts_group_mask_ph}

            net_output = model.get_model(pointclouds_ph, is_training_ph, group_cate_num=NUM_CATEGORY, m=MARGINS[0])
            loss, grouperr, same, same_cnt, diff, diff_cnt, pos, pos_cnt = model.get_loss(net_output, labels, alpha_ph, MARGINS)

            total_training_loss_ph = tf.placeholder(tf.float32, shape=())
            group_err_loss_ph = tf.placeholder(tf.float32, shape=())
            total_train_loss_sum_op = tf.summary.scalar('total_training_loss', total_training_loss_ph)
            group_err_op = tf.summary.scalar('group_err_loss', group_err_loss_ph)

        train_variables = tf.trainable_variables()

        trainer = tf.train.AdamOptimizer(learning_rate)
        train_op = trainer.minimize(loss, var_list=train_variables, global_step=batch)

        loader = tf.train.Saver([v for v in tf.all_variables()
                                 if ('conf_logits' not in v.name) and
                                    ('Fsim' not in v.name) and
                                    ('Fsconf' not in v.name) and
                                    ('batch' not in v.name)])

        saver = tf.train.Saver([v for v in tf.all_variables()])

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.Session(config=config)

        init = tf.global_variables_initializer()
        sess.run(init)

        train_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/train', sess.graph)

        import glob

        train_file_list = glob.glob(os.path.join(TRAINING_FILE_LIST, '*.h5'))
        # train_file_list=provider.getDataFiles(TRAINING_FILE_LIST)
        num_train_file = len(train_file_list)

        fcmd = open(os.path.join(LOG_STORAGE_PATH, 'cmd.txt'), 'w')
        fcmd.write(str(FLAGS))
        fcmd.close()

        flog = open(os.path.join(LOG_STORAGE_PATH, 'log.txt'), 'w')

        ckptstate = tf.train.get_checkpoint_state(PRETRAINED_MODEL_PATH)
        if ckptstate is not None:
            LOAD_MODEL_FILE = os.path.join(PRETRAINED_MODEL_PATH, os.path.basename(ckptstate.model_checkpoint_path))
            loader.restore(sess, LOAD_MODEL_FILE)
            printout(flog, "Model loaded in file: %s" % LOAD_MODEL_FILE)
        else:
            printout(flog, "Fail to load modelfile: %s" % PRETRAINED_MODEL_PATH)


        train_file_idx = np.arange(0, len(train_file_list))
        np.random.shuffle(train_file_idx)

        ## load all data into memory
        all_data = []
        all_group = []
        all_seg = []
        for i in range(num_train_file):
            cur_train_filename = train_file_list[train_file_idx[i]]
            # printout(flog, 'Loading train file ' + cur_train_filename)
            cur_data, cur_group, _, cur_seg = provider.loadDataFile_with_groupseglabel_stanfordindoor(cur_train_filename)
            all_data += [cur_data]
            all_group += [cur_group]
            all_seg += [cur_seg]

        all_data = np.concatenate(all_data,axis=0)
        all_group = np.concatenate(all_group,axis=0)
        all_seg = np.concatenate(all_seg,axis=0)

        num_data = all_data.shape[0]
        num_batch = num_data // BATCH_SIZE

        def train_one_epoch(epoch_num):

            ### NOTE: is_training = False means we do not update BN parameters during training, due to the small batch size. This requires pre-training PointNet with a large batch size (say 32).
            is_training = False

            order = np.arange(num_data)
            np.random.shuffle(order)

            total_loss = 0.0
            total_grouperr = 0.0
            total_same = 0.0
            total_diff = 0.0
            total_pos = 0.0
            same_cnt0 = 0

            for j in range(num_batch):
                begidx = j * BATCH_SIZE
                endidx = (j + 1) * BATCH_SIZE

                pts_label_one_hot, pts_label_mask = model.convert_seg_to_one_hot(all_seg[order[begidx: endidx]])
                pts_group_label, pts_group_mask = model.convert_groupandcate_to_one_hot(all_group[order[begidx: endidx]])

                feed_dict = {
                    pointclouds_ph: all_data[order[begidx: endidx], ...],
                    ptsseglabel_ph: pts_label_one_hot,
                    ptsgroup_label_ph: pts_group_label,
                    pts_seglabel_mask_ph: pts_label_mask,
                    pts_group_mask_ph: pts_group_mask,
                    is_training_ph: is_training,
                    alpha_ph: min(10., (float(epoch_num) / 5.) * 2. + 2.),
                }

                _, loss_val, simmat_val, grouperr_val, same_val, same_cnt_val, diff_val, diff_cnt_val, pos_val, pos_cnt_val = sess.run([train_op, loss, net_output['simmat'], grouperr, same, same_cnt, diff, diff_cnt, pos, pos_cnt], feed_dict=feed_dict)
                total_loss += loss_val
                total_grouperr += grouperr_val
                if diff_cnt_val!=0:
                    total_diff += (diff_val / diff_cnt_val)
                if same_cnt_val > 0:
                    total_same += same_val / same_cnt_val
                    same_cnt0 += 1
                total_pos += pos_val / pos_cnt_val


                if j % 10 == 9:
                    if same_cnt0!=0:
                        printout(flog, 'Batch: %d, loss: %f, grouperr: %f, same: %f, diff: %f, pos: %f' % (j, total_loss/10, total_grouperr/10, total_same/same_cnt0, total_diff/10, total_pos/10))
                    else:
                        printout(flog, 'Batch: %d, loss: %f, grouperr: %f, same: %f, diff: %f, pos: %f' % (j, total_loss/10, total_grouperr/10, 1e10, total_diff/10, total_pos/10))

                    lr_sum, batch_sum, train_loss_sum, group_err_sum = sess.run( \
                        [lr_op, batch, total_train_loss_sum_op, group_err_op], \
                        feed_dict={total_training_loss_ph: total_loss / 10.,
                                   group_err_loss_ph: total_grouperr / 10., })

                    train_writer.add_summary(train_loss_sum, batch_sum)
                    train_writer.add_summary(lr_sum, batch_sum)
                    train_writer.add_summary(group_err_sum, batch_sum)

                    total_grouperr = 0.0
                    total_loss = 0.0
                    total_diff = 0.0
                    total_same = 0.0
                    total_pos = 0.0
                    same_cnt0 = 0



            cp_filename = saver.save(sess,
                                     os.path.join(MODEL_STORAGE_PATH, 'epoch_' + str(epoch_num + 1) + '.ckpt'))
            printout(flog, 'Successfully store the checkpoint model into ' + cp_filename)

        if not os.path.exists(MODEL_STORAGE_PATH):
            os.mkdir(MODEL_STORAGE_PATH)

        for epoch in range(TRAINING_EPOCHES):
            printout(flog, '\n>>> Training for the epoch %d/%d ...' % (epoch, TRAINING_EPOCHES))

            train_file_idx = np.arange(0, len(train_file_list))
            np.random.shuffle(train_file_idx)

            train_one_epoch(epoch)
            flog.flush()

            cp_filename = saver.save(sess,
                                     os.path.join(MODEL_STORAGE_PATH, 'epoch_' + str(epoch + 1) + '.ckpt'))
            printout(flog, 'Successfully store the checkpoint model into ' + cp_filename)


        flog.close()
Example #7
def predict():
    """
    Load the selected checkpoint and predict the labels.
    Write both the ground truth and the predictions to the output directories.
    This enables side-by-side visualization of the predictions and the true
    labels, which helps with debugging the network.
    """
    with tf.device('/gpu:' + str(GPU_INDEX)):
        pointclouds_pl, labels_pl, _ = MODEL.placeholder_inputs(
            1, NUM_POINT, hyperparams=PARAMS)
        print(tf.shape(pointclouds_pl))
        is_training_pl = tf.placeholder(tf.bool, shape=())

        # simple model
        pred, _ = MODEL.get_model(pointclouds_pl,
                                  is_training_pl,
                                  NUM_CLASSES,
                                  hyperparams=PARAMS)

        # Add ops to save and restore all the variables.
        saver = tf.train.Saver()

    # Create a session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = False
    sess = tf.Session(config=config)

    # Restore variables from disk.
    saver.restore(sess, CHECKPOINT)
    print("Model restored.")

    ops = {
        'pointclouds_pl': pointclouds_pl,
        'labels_pl': labels_pl,
        'is_training_pl': is_training_pl,
        'pred': pred
    }

    if EXPORT_FULL_POINT_CLOUDS:
        OUTPUT_DIR_FULL_PC = os.path.join(OUTPUT_DIR,
                                          "full_scenes_predictions")
        if not os.path.exists(OUTPUT_DIR_FULL_PC):
            os.mkdir(OUTPUT_DIR_FULL_PC)
        nscenes = len(DATASET)
        p = 6 if PARAMS['use_color'] else 3
        scene_points = [np.array([]).reshape((0, p)) for i in range(nscenes)]
        ground_truth = [np.array([]) for i in range(nscenes)]
        predicted_labels = [np.array([]) for i in range(nscenes)]
        for i in range(N * nscenes):
            if i % 100 == 0 and i > 0:
                print("{} inputs generated".format(i))
            f, data, raw_data, true_labels, col, _ = DATASET.next_input(
                DROPOUT, True, False, predicting=True)
            if p == 6:
                raw_data = np.hstack((raw_data, col))
                data = np.hstack((data, col))
            pred_labels = predict_one_input(sess, ops, data)
            scene_points[f] = np.vstack((scene_points[f], raw_data))
            ground_truth[f] = np.hstack((ground_truth[f], true_labels))
            predicted_labels[f] = np.hstack((predicted_labels[f], pred_labels))
        filenames = DATASET.get_data_filenames()
        filenamesForExport = filenames[0:min(len(filenames), MAX_EXPORT)]
        print("{} point clouds to export".format(len(filenamesForExport)))
        for f, filename in enumerate(filenamesForExport):
            print("exporting file {} which has {} points".format(
                os.path.basename(filename), len(ground_truth[f])))
            pc_util.write_ply_color(
                scene_points[f][:, 0:3], ground_truth[f],
                OUTPUT_DIR_FULL_PC + "/{}_groundtruth.txt".format(os.path.basename(filename)),
                NUM_CLASSES)
            pc_util.write_ply_color(
                scene_points[f][:, 0:3], predicted_labels[f],
                OUTPUT_DIR_FULL_PC + "/{}_aggregated.txt".format(os.path.basename(filename)),
                NUM_CLASSES)
            np.savetxt(OUTPUT_DIR_FULL_PC +
                       "/{}_pred.txt".format(os.path.basename(filename)),
                       predicted_labels[f].reshape((-1, 1)),
                       delimiter=" ")
        print("done.")
        return

    if not os.path.exists(OUTPUT_DIR_GROUNDTRUTH):
        os.mkdir(OUTPUT_DIR_GROUNDTRUTH)
    if not os.path.exists(OUTPUT_DIR_PREDICTION):
        os.mkdir(OUTPUT_DIR_PREDICTION)

    # To add the histograms
    meta_hist_true = np.zeros(9)
    meta_hist_pred = np.zeros(9)

    for idx in range(N):
        data, true_labels, _, _ = DATASET.next_input(DROPOUT, True, False)
        # Ground truth
        print("Exporting scene number " + str(idx))
        pc_util.write_ply_color(
            data[:, 0:3], true_labels,
            OUTPUT_DIR_GROUNDTRUTH + "/{}_{}.txt".format(SET, idx),
            NUM_CLASSES)

        # Prediction
        pred_labels = predict_one_input(sess, ops, data)

        # Compute mean IoU
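        # tf.metrics.mean_iou returns (iou, update_op): update_op accumulates
        # a confusion matrix in TF local variables and `iou` is derived from
        # it. Re-initializing the local variables below resets that state, so
        # each scene's IoU is computed independently.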
        iou, update_op = tf.metrics.mean_iou(tf.to_int64(true_labels),
                                             tf.to_int64(pred_labels),
                                             NUM_CLASSES)
        sess.run(tf.local_variables_initializer())
        update_op.eval(session=sess)
        print(sess.run(iou))

        hist_true, _ = np.histogram(true_labels, range(NUM_CLASSES + 1))
        hist_pred, _ = np.histogram(pred_labels, range(NUM_CLASSES + 1))

        # update meta histograms
        meta_hist_true += hist_true
        meta_hist_pred += hist_pred

        # print individual histograms
        print(hist_true)
        print(hist_pred)

        pc_util.write_ply_color(
            data[:, 0:3], pred_labels,
            "{}/{}_{}.txt".format(OUTPUT_DIR_PREDICTION, SET,
                                  idx), NUM_CLASSES)

    meta_hist_pred = (meta_hist_pred / sum(meta_hist_pred)) * 100
    meta_hist_true = (meta_hist_true / sum(meta_hist_true)) * 100
    print(LABELS_TEXT)
    print(meta_hist_true)
    print(meta_hist_pred)