Example #1
def predict():
    is_training = False
    
    with tf.device('/gpu:' + str(gpu_to_use)):
        images_ph, edges_ph = placeholder_inputs()
        is_training_ph = tf.placeholder(tf.bool, shape=())

        # simple model
        edge_logits = model.get_model(images_ph, is_training=is_training_ph)
        loss = model.get_loss(edge_logits, edges_ph)

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()

    # Later, launch the model, use the saver to restore variables from disk, and
    # do some work with the model.
    
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.Session(config=config) as sess:
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)

        flog = open(os.path.join(output_dir, 'log.txt'), 'w')

        # Restore variables from disk.
        ckpt_dir = './train_results/trained_models'
        if not load_checkpoint(ckpt_dir, sess):
            sess.run(tf.global_variables_initializer())

        if not os.path.exists('data/mv-rnn'):
            os.makedirs('data/mv-rnn')
        for l in range(len(TESTING_FILE_LIST)):
            images = np.zeros((1, NUM_VIEWS, IMAGE_SIZE, IMAGE_SIZE, 1))
            model_name = TESTING_FILE_LIST[l]
            if not os.path.exists('data/mv-rnn/' + model_name):
                os.makedirs('data/mv-rnn/' + model_name)
            for v in range(NUM_VIEWS):
                images[0, v, :, :, 0] = np.array(
                    scipy.ndimage.imread('data/rgb/' + model_name + '/RGB-' +
                                         str(v).zfill(3) + '.png', mode='L'),
                    dtype=np.float32)
            edge_logits_val = sess.run(edge_logits, feed_dict={images_ph: images, is_training_ph: is_training})
            edges = sigmoid(edge_logits_val)
            for v in range(NUM_VIEWS):
                scipy.misc.imsave('data/mv-rnn/' + model_name + '/MV-RNN-' +
                                  str(v).zfill(3) + '.png',
                                  edges[0, v, :, :, 0])

            printout(flog, '[%2d/%2d] model %s' % ((l+1), len(TESTING_FILE_LIST), TESTING_FILE_LIST[l]))
            printout(flog, '----------')
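
Note: sigmoid and printout are helpers defined elsewhere in the project. A minimal sketch of plausible implementations (assumptions, not the project's actual code):

import numpy as np

def sigmoid(x):
    # element-wise logistic, mapping edge logits to edge probabilities
    return 1.0 / (1.0 + np.exp(-x))

def printout(flog, data):
    # hypothetical: mirror a message to stdout and to the open log file
    print(data)
    flog.write(data + '\n')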
Example #2
File: train.py  Project: Lester-liu/MVRNN
def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:' + str(FLAGS.gpu)):
            images_ph, edges_ph = placeholder_inputs()
            is_training_ph = tf.placeholder(tf.bool, shape=())

            queue = tf.FIFOQueue(
                capacity=10 * batch_size,
                dtypes=[tf.float32, tf.float32],
                shapes=[[NUM_VIEWS, IMAGE_SIZE, IMAGE_SIZE, 1],
                        [NUM_VIEWS, IMAGE_SIZE, IMAGE_SIZE, 1]])
            enqueue_op = queue.enqueue([images_ph, edges_ph])
            dequeue_images_ph, dequeue_edges_ph = queue.dequeue_many(
                batch_size)

            # model and loss
            edge_logits = model.get_model(dequeue_images_ph,
                                          is_training=is_training_ph)
            loss = model.get_loss(edge_logits, dequeue_edges_ph)

            # optimization
            total_var = tf.trainable_variables()
            train_step = tf.train.AdamOptimizer(
                learning_rate=LEARNING_RATE).minimize(loss, var_list=total_var)

        # write logs to the disk
        flog = open(os.path.join(LOG_STORAGE_PATH, 'log.txt'), 'w')

        saver = tf.train.Saver()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.Session(config=config)

        ckpt_dir = './train_results/trained_models'
        if not load_checkpoint(ckpt_dir, sess):
            sess.run(tf.global_variables_initializer())

        train_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/test')

        fcmd = open(os.path.join(LOG_STORAGE_PATH, 'cmd.txt'), 'w')
        fcmd.write(str(FLAGS))
        fcmd.close()

        def train_one_epoch(epoch_num):
            is_training = True

            num_data = len(TRAINING_FILE_LIST)
            num_batch = num_data // batch_size
            loss_acc = 0.0
            display_mark = max([num_batch // 4, 1])
            for i in range(num_batch):
                _, loss_val = sess.run([train_step, loss],
                                       feed_dict={is_training_ph: is_training})
                loss_acc += loss_val
                if ((i + 1) % display_mark == 0):
                    printout(
                        flog, 'Epoch %3d/%3d - Iter %4d/%d' %
                        (epoch_num + 1, TRAINING_EPOCHES, i + 1, num_batch))
                    printout(flog, 'total loss: %f' % (loss_acc / (i + 1)))

            loss_acc = loss_acc * 1.0 / num_batch

            printout(flog, '\tMean total Loss: %f' % loss_acc)

        if not os.path.exists(MODEL_STORAGE_PATH):
            os.mkdir(MODEL_STORAGE_PATH)

        coord = tf.train.Coordinator()
        for num_thread in range(4):
            t = StoppableThread(target=load_and_enqueue,
                                args=(sess, enqueue_op, images_ph, edges_ph))
            t.daemon = True
            t.start()
            coord.register_thread(t)

        for epoch in range(TRAINING_EPOCHES):
            printout(
                flog, '\n>>> Training for the epoch %d/%d ...' %
                (epoch + 1, TRAINING_EPOCHES))

            train_one_epoch(epoch)

            if (epoch + 1) % 1 == 0:
                cp_filename = saver.save(
                    sess,
                    os.path.join(MODEL_STORAGE_PATH,
                                 'epoch_' + str(epoch + 1) + '.ckpt'))
                printout(
                    flog, 'Successfully store the checkpoint model into ' +
                    cp_filename)

            flog.flush()
        flog.close()
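
Note: StoppableThread and load_and_enqueue are project helpers that are not shown. A plausible sketch of the producer pattern they implement (hypothetical; random arrays stand in for the real data loading):

import threading
import numpy as np

class StoppableThread(threading.Thread):
    # hypothetical: a Thread carrying a stop flag that its target can poll
    def __init__(self, *args, **kwargs):
        super(StoppableThread, self).__init__(*args, **kwargs)
        self._stop_event = threading.Event()

    def stop(self):
        self._stop_event.set()

    def stopped(self):
        return self._stop_event.is_set()

def load_and_enqueue(sess, enqueue_op, images_ph, edges_ph):
    # hypothetical: keep feeding one (images, edges) pair into the FIFOQueue;
    # queue.enqueue blocks when the queue is full, throttling the producers
    while True:
        images = np.random.rand(NUM_VIEWS, IMAGE_SIZE, IMAGE_SIZE, 1)
        edges = np.random.rand(NUM_VIEWS, IMAGE_SIZE, IMAGE_SIZE, 1)
        sess.run(enqueue_op, feed_dict={images_ph: images, edges_ph: edges})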
Example #3
def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:' + str(FLAGS.gpu)):
            pointgrid_ph, cat_label_ph, seg_label_ph = placeholder_inputs()
            is_training_ph = tf.placeholder(tf.bool, shape=())

            queue = tf.FIFOQueue(
                capacity=20 * batch_size,
                dtypes=[tf.float32, tf.float32, tf.float32],
                shapes=[[model.N, model.N, model.N, model.NUM_FEATURES],
                        [model.NUM_CATEGORY],
                        [model.N, model.N, model.N, model.K + 1, model.NUM_SEG_PART]])
            enqueue_op = queue.enqueue(
                [pointgrid_ph, cat_label_ph, seg_label_ph])
            dequeue_pointgrid, dequeue_cat_label, dequeue_seg_label = queue.dequeue_many(
                batch_size)

            # model
            pred_cat, pred_seg = model.get_model(dequeue_pointgrid,
                                                 is_training=is_training_ph)

            # loss
            total_loss, cat_loss, seg_loss = model.get_loss(
                pred_cat, dequeue_cat_label, pred_seg, dequeue_seg_label)

            # optimization
            total_var = tf.trainable_variables()
            step = tf.train.AdamOptimizer(
                learning_rate=LEARNING_RATE).minimize(total_loss,
                                                      var_list=total_var)

        # write logs to the disk
        flog = open(os.path.join(LOG_STORAGE_PATH, 'log.txt'), 'w')

        saver = tf.train.Saver()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.Session(config=config)

        ckpt_dir = './train_results/trained_models'
        if not load_checkpoint(ckpt_dir, sess):
            sess.run(tf.global_variables_initializer())

        train_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/test')

        fcmd = open(os.path.join(LOG_STORAGE_PATH, 'cmd.txt'), 'w')
        fcmd.write(str(FLAGS))
        fcmd.close()

        def train_one_epoch(epoch_num):
            is_training = True

            num_data = len(TRAINING_FILE_LIST)
            num_batch = num_data // batch_size
            total_loss_acc = 0.0
            cat_loss_acc = 0.0
            seg_loss_acc = 0.0
            display_mark = max([num_batch // 4, 1])
            for i in range(num_batch):
                _, total_loss_val, cat_loss_val, seg_loss_val = sess.run(
                    [step, total_loss, cat_loss, seg_loss],
                    feed_dict={is_training_ph: is_training})
                total_loss_acc += total_loss_val
                cat_loss_acc += cat_loss_val
                seg_loss_acc += seg_loss_val

                if ((i + 1) % display_mark == 0):
                    printout(
                        flog, 'Epoch %d/%d - Iter %d/%d' %
                        (epoch_num + 1, TRAINING_EPOCHES, i + 1, num_batch))
                    printout(flog,
                             'Total Loss: %f' % (total_loss_acc / (i + 1)))
                    printout(
                        flog,
                        'Classification Loss: %f' % (cat_loss_acc / (i + 1)))
                    printout(
                        flog,
                        'Segmentation Loss: %f' % (seg_loss_acc / (i + 1)))

            printout(flog,
                     '\tMean Total Loss: %f' % (total_loss_acc / num_batch))
            printout(
                flog,
                '\tMean Classification Loss: %f' % (cat_loss_acc / num_batch))
            printout(
                flog,
                '\tMean Segmentation Loss: %f' % (seg_loss_acc / num_batch))

        if not os.path.exists(MODEL_STORAGE_PATH):
            os.mkdir(MODEL_STORAGE_PATH)

        coord = tf.train.Coordinator()
        for num_thread in range(16):
            t = StoppableThread(target=load_and_enqueue,
                                args=(sess, enqueue_op, pointgrid_ph,
                                      cat_label_ph, seg_label_ph))
            t.daemon = True
            t.start()
            coord.register_thread(t)

        for epoch in range(TRAINING_EPOCHES):
            printout(
                flog, '\n>>> Training for the epoch %d/%d ...' %
                (epoch + 1, TRAINING_EPOCHES))

            train_one_epoch(epoch)

            if (epoch + 1) % 1 == 0:
                cp_filename = saver.save(
                    sess,
                    os.path.join(MODEL_STORAGE_PATH,
                                 'epoch_' + str(epoch + 1) + '.ckpt'))
                printout(
                    flog, 'Successfully store the checkpoint model into ' +
                    cp_filename)

            flog.flush()
        flog.close()
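
Note: load_checkpoint, used by most of these examples, is defined elsewhere in the project. A minimal sketch of the standard TF1 restore pattern it presumably wraps (an assumption, not the project's code):

import tensorflow as tf

def load_checkpoint(ckpt_dir, sess):
    # hypothetical: restore the newest checkpoint in ckpt_dir if one exists;
    # return True on success so the caller can skip re-initialization
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)
        return True
    return False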
Example #4
def train():
    with tf.device('/gpu:' + str(GPU_IDX)):
        # Why placeholders: we need these tensors to build the graph now, but
        # their concrete values only arrive later and differ between phases,
        # so a placeholder reserves the slot and the actual data is fed into
        # the graph afterwards (via feed_dict).
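        # For example, a later training step might feed (batch names hypothetical):
        #   sess.run(train_op_seg, feed_dict={img_pl: batch_imgs,
        #                                     label_seg_pl: batch_segs,
        #                                     is_training_pl: True})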
        img_pl, label_seg_pl, label_cls_pl = \
            network.placeholder_inputs(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE)
        is_training_pl = tf.placeholder(tf.bool, shape=())
        print(is_training_pl)

        global_step = tf.Variable(0, trainable=False)
        bn_decay = get_bn_decay(global_step)

        seg_out, seg_out_former = network.seg_model(img_pl, is_training_pl,
                                                    bn_decay)
        seg_pred, loss_seg = network.get_loss(label_seg_pl, seg_out, 'seg')

        cls_out = network.cls_model(seg_out, seg_out_former, is_training_pl,
                                    bn_decay)
        cls_pred, loss_cls = network.get_loss(label_cls_pl, cls_out, 'cls')

        learning_rate = get_learning_rate(global_step)
        tf.summary.scalar('learning_rate', learning_rate)

        if OPTIMIZER == 'momentum':
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   momentum=MOMENTUM)
        elif OPTIMIZER == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate)

        train_op_seg = optimizer.minimize(loss_seg, global_step=global_step)
        train_op_cls = optimizer.minimize(loss_cls, global_step=global_step)
        # global_step: optional Variable that minimize() increments by one
        # each time the variables are updated.
        # One full pass over the training set is an epoch. Within an epoch,
        # every batch triggers one parameter update, called an iteration (or
        # step); global_step counts iterations accumulated over all epochs.
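        # Note: both train_op_seg and train_op_cls above pass global_step to
        # minimize(), so seg updates and cls updates each increment it; e.g.
        # 100 seg steps followed by 100 cls steps leave global_step at 200.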

        saver = tf.train.Saver()

    # Create a session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    config.log_device_placement = True
    sess = tf.Session(config=config)

    init = tf.global_variables_initializer()

    sess.run(init)

    ops_seg = {
        'stage': 'seg',
        'img_pl': img_pl,
        'label_pl': label_seg_pl,
        'is_training_pl': is_training_pl,
        'pred': seg_pred,
        'loss': loss_seg,
        'train_op': train_op_seg,
        'step': global_step
    }

    ops_cls = {
        'stage': 'cls',
        'img_pl': img_pl,
        'label_pl': label_cls_pl,
        'is_training_pl': is_training_pl,
        'pred': cls_pred,
        'loss': loss_cls,
        'train_op': train_op_cls,
        'step': global_step
    }

    for e in range(MAX_EPOCH):
        log_string('----- EPOCH %03d -----' % e)
        for epoch in range(MAX_EPOCH_seg):
            log_string('----- SEG EPOCH %03d -----' % epoch)
            train_one_epoch(sess, ops_seg)
            eval_one_epoch(sess, ops_seg)

        for epoch in range(MAX_EPOCH_cls):
            log_string('----- CLS EPOCH %03d -----' % epoch)
            train_one_epoch(sess, ops_cls)
            eval_one_epoch(sess, ops_cls)

        if e % 1 == 0:
            save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
            log_string("Model saved in file: %s" % save_path)
Example #5
def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:' + str(FLAGS.gpu)):
            pointgrid_ph, seg_label_ph = placeholder_inputs()
            is_training_ph = tf.placeholder(tf.bool, shape=())

            # model
            pred_seg = model.get_model(pointgrid_ph,
                                       is_training=is_training_ph)
            total_loss, seg_loss = model.get_loss(pred_seg, seg_label_ph)

            # optimization
            total_var = tf.trainable_variables()
            step = tf.train.AdamOptimizer(
                learning_rate=LEARNING_RATE).minimize(total_loss,
                                                      var_list=total_var)

        # write logs to the disk
        flog = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
        saver = tf.train.Saver()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.Session(config=config)

        if not load_checkpoint(ckpt_dir, sess):
            sess.run(tf.global_variables_initializer())

        # Add summary writers
        train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
                                             sess.graph)
        test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'))

        fcmd = open(os.path.join(LOG_DIR, 'cmd.txt'), 'w')
        fcmd.write(str(FLAGS))
        fcmd.close()

        def train_one_epoch(epoch_num, sess, train_writer):
            is_training = True
            current_data, current_label, _ = shuffle_data(
                train_data, train_label)
            num_data = train_data.shape[0]
            num_batch = num_data // BATCH_SIZE
            total_loss_acc = 0.0
            seg_loss_acc = 0.0

            for batch_idx in range(num_batch):

                start_idx = batch_idx * BATCH_SIZE
                end_idx = (batch_idx + 1) * BATCH_SIZE

                pointgrid, pointgrid_label = transfor_data(
                    current_data[start_idx:end_idx, :, :],
                    current_label[start_idx:end_idx, :])

                feed_dict = {
                    is_training_ph: is_training,
                    pointgrid_ph: pointgrid,
                    seg_label_ph: pointgrid_label
                }

                _, total_loss_val, seg_loss_val = sess.run(
                    [step, total_loss, seg_loss], feed_dict=feed_dict)
                # train_writer.add_summary(total_loss_val,seg_loss_val)
                total_loss_acc += total_loss_val
                seg_loss_acc += seg_loss_val

                if batch_idx % 100 == 0:
                    print('Current batch/total batch num: %d/%d' %
                          (batch_idx, num_batch))
                    printout(
                        flog, 'Epoch %d/%d - Iter %d/%d' %
                        (epoch_num + 1, TRAINING_EPOCHES, batch_idx + 1,
                         num_batch))
                    printout(
                        flog,
                        'Total Loss: %f' % (total_loss_acc / (batch_idx + 1)))
                    printout(
                        flog, 'Segmentation Loss: %f' % (seg_loss_acc /
                                                         (batch_idx + 1)))

            printout(flog,
                     '\tMean Total Loss: %f' % (total_loss_acc / num_batch))
            printout(
                flog,
                '\tMean Segmentation Loss: %f' % (seg_loss_acc / num_batch))

        def test_one_epoch(sess, test_writer):
            is_training = False
            current_data, current_label, _ = shuffle_data(
                test_data, test_label)
            num_data = test_data.shape[0]
            num_batch = num_data // BATCH_SIZE
            total_loss_acc = 0.0
            seg_loss_acc = 0.0
            total_correct = 0
            total_seen = 0
            total_seen_class = [0 for _ in range(model.SEG_PART)]
            total_correct_class = [0 for _ in range(model.SEG_PART)]

            for batch_idx in range(num_batch):
                if batch_idx % 100 == 0:
                    print('Current batch/total batch num: %d/%d' %
                          (batch_idx, num_batch))
                start_idx = batch_idx * BATCH_SIZE
                end_idx = (batch_idx + 1) * BATCH_SIZE

                pointgrid, pointgrid_label = transfor_data(
                    current_data[start_idx:end_idx, :, :],
                    current_label[start_idx:end_idx, :])

                feed_dict = {
                    is_training_ph: is_training,
                    pointgrid_ph: pointgrid,
                    seg_label_ph: pointgrid_label
                }

                total_loss_val, seg_loss_val, pred_val = sess.run(
                    [total_loss, seg_loss, pred_seg], feed_dict=feed_dict)
                # test_writer.add_summary(step, total_loss_val, seg_loss_val)
                total_loss_acc += total_loss_val
                seg_loss_acc += seg_loss_val

                pred_val = np.argmax(pred_val, 2)
                TP = np.sum(pred_val == current_label[start_idx:end_idx])
                total_correct += TP
                total_seen += (BATCH_SIZE * model.NUM_POINT)

                for i in range(start_idx, end_idx):
                    for j in range(model.NUM_POINT):
                        l = current_label[i, j]
                        total_seen_class[l] += 1
                        total_correct_class[l] += (pred_val[i - start_idx,
                                                            j] == l)

            printout(flog,
                     'eval accuracy: %f' % (total_correct / float(total_seen)))
            printout(
                flog, 'eval avg class acc: %f' % (np.mean(
                    np.array(total_correct_class) /
                    np.array(total_seen_class, dtype=np.float32))))

            printout(flog,
                     '\tMean Total Loss: %f' % (total_loss_acc / num_batch))
            printout(
                flog,
                '\tMean Segmentation Loss: %f' % (seg_loss_acc / num_batch))

        for epoch in range(TRAINING_EPOCHES):
            printout(
                flog, '\n>>> Training for the epoch %d/%d ...' %
                (epoch + 1, TRAINING_EPOCHES))

            train_one_epoch(epoch, sess, train_writer)
            # test_one_epoch(sess, test_writer)

            if epoch % 5 == 0:
                cp_filename = saver.save(
                    sess,
                    os.path.join(LOG_DIR, 'epoch_' + str(epoch + 1) + '.ckpt'))
                printout(
                    flog, 'Successfully store the checkpoint model into ' +
                    cp_filename)

            flog.flush()
        flog.close()
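
Note: shuffle_data and transfor_data are project helpers. shuffle_data is commonly a joint permutation of data and labels; a sketch under that assumption (matching the three-value call site above):

import numpy as np

def shuffle_data(data, labels):
    # hypothetical: shuffle data and labels with one shared permutation and
    # also return the permutation, matching `current_data, current_label, _ = ...`
    idx = np.arange(data.shape[0])
    np.random.shuffle(idx)
    return data[idx, ...], labels[idx], idx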
Example #6
def run_train(lyftdata, train, config):

    use_bn = config.get('use_bn', False)
    bn_level = config.get('bn_level', 0)
    load_path = config.get('load_path', '')
    save_path = config.get('save_path', '')
    sub_path = config.get('sub_path', '')
    save_name = config.get('save_name', 'save')
    lr = config.get('lr', 1e-4)
    lr_decay_per_epoch = config.get('lr_decay_per_epoch', 0.2)
    epochs = config.get('epochs', 3)
    cls_weight = config.get('cls_weight', 10.)
    reg_weight = config.get('reg_weight', 50.)

    train_split = config.get('train_split', 0.9)
    seed = config.get('seed', 0)
    workers = config.get('workers', 1)
    use_multiprocessing = config.get('use_multiprocessing', False)

    if not save_path.endswith('/'):
        save_path += '/'

    if sub_path == '':
        i = 0
        while os.path.exists(save_path + '{:05d}/'.format(i)):
            i += 1
        sub_path = '{:05d}/'.format(i)
        os.mkdir(save_path + sub_path)

    save_path += (sub_path + save_name)

    histories = Histories()
    if 'callbacks' not in config:
        callbacks = [histories]
    else:
        callbacks = config['callbacks'] + [histories]

    gen = data_generator(train, lyftdata, config=config)
    np.random.seed(seed)
    perm = np.random.permutation(gen.train.shape[0])
    train_idx = perm[0:int(train_split * gen.train.shape[0])]
    val_idx = perm[int(train_split * gen.train.shape[0]):]

    train_gen = train_generator(train_idx,
                                gen,
                                batch_size=4,
                                shuffle=True,
                                seed=None)
    val_gen = evaluation_generator(val_idx, gen, batch_size=4)

    if load_path != '':
        model = load_model(load_path)
        loss, cls_loss, reg_loss = get_loss(len(gen.inc_classes), cls_weight,
                                            reg_weight)
        model.compile(optimizer=Adam(lr),
                      loss=loss,
                      metrics=[cls_loss, reg_loss])
    else:
        model = get_model(gen.shape,
                          len(gen.inc_classes),
                          use_bn=use_bn,
                          bn_level=bn_level,
                          expand_channels=4,
                          cls_weight=cls_weight,
                          reg_weight=reg_weight,
                          optimizer=Adam(lr))

    def scheduler(epoch, lr):
        if epoch == 0:
            return lr
        else:
            return lr * lr_decay_per_epoch

    lr_scheduler = LearningRateScheduler(scheduler, verbose=1)
    callbacks.append(lr_scheduler)
    callbacks.append(SaveCheckPoints(frequency=1000, path=save_path))

    hist = model.fit_generator(train_gen,
                               epochs=epochs,
                               use_multiprocessing=use_multiprocessing,
                               workers=workers,
                               callbacks=callbacks,
                               validation_data=val_gen)

    with open(save_path + '_config.json', 'w') as f:
        json.dump(config, f)

    return model, histories, hist
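
Note: a hypothetical invocation of run_train; every key below is read via config.get above, and the values are placeholders rather than recommendations:

config = {
    'save_path': './runs',    # parent directory for numbered run folders
    'save_name': 'baseline',  # checkpoint file prefix
    'lr': 1e-4,
    'epochs': 3,
    'workers': 4,
}
model, histories, hist = run_train(lyftdata, train, config)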
Example #7
def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:' + str(FLAGS.gpu)):
            pointgrid_ph, seg_label_ph = placeholder_inputs()
            is_training_ph = tf.placeholder(tf.bool, shape=())

            queue = tf.FIFOQueue(
                capacity=20 * batch_size,
                dtypes=[tf.float32, tf.float32],
                shapes=[[model.N, model.N, model.N, model.NUM_FEATURES],
                        [model.N, model.N, model.N, model.K + 1, model.NUM_SEG_PART]])
            enqueue_op = queue.enqueue([pointgrid_ph, seg_label_ph])
            dequeue_pointgrid, dequeue_seg_label = queue.dequeue_many(
                batch_size)

            # model
            pred_seg = model.get_model(dequeue_pointgrid,
                                       is_training=is_training_ph)

            # loss
            total_loss, seg_loss = model.get_loss(pred_seg, dequeue_seg_label)

            # optimization
            total_var = tf.trainable_variables()
            step = tf.train.AdamOptimizer(
                learning_rate=LEARNING_RATE).minimize(total_loss,
                                                      var_list=total_var)

        # write logs to the disk
        flog = open(os.path.join(LOG_STORAGE_PATH, 'log_train.txt'), 'w')

        saver = tf.train.Saver()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.allow_soft_placement = True
        sess = tf.Session(config=config)

        ckpt_dir = './train_results/trained_models'
        if not load_checkpoint(ckpt_dir, sess):
            sess.run(tf.global_variables_initializer())

        train_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/train',
                                             sess.graph)
        test_writer = tf.summary.FileWriter(SUMMARIES_FOLDER + '/test')

        fcmd = open(os.path.join(LOG_STORAGE_PATH, 'cmd.txt'), 'w')
        fcmd.write(str(FLAGS))
        fcmd.close()

        def train_one_epoch(epoch_num):
            is_training = True

            num_data = train_data.shape[0]
            num_batch = num_data // batch_size
            total_loss_acc = 0.0
            seg_loss_acc = 0.0
            display_mark = max([num_batch // 4, 1])
            for i in range(num_batch):
                _, total_loss_val, seg_loss_val = sess.run(
                    [step, total_loss, seg_loss],
                    feed_dict={is_training_ph: is_training})
                total_loss_acc += total_loss_val
                seg_loss_acc += seg_loss_val

                if ((i + 1) % display_mark == 0):
                    printout(
                        flog, 'Epoch %d/%d - Iter %d/%d' %
                        (epoch_num + 1, TRAINING_EPOCHES, i + 1, num_batch))
                    printout(flog,
                             'Total Loss: %f' % (total_loss_acc / (i + 1)))
                    printout(
                        flog,
                        'Segmentation Loss: %f' % (seg_loss_acc / (i + 1)))

            printout(flog,
                     '\tMean Total Loss: %f' % (total_loss_acc / num_batch))
            printout(
                flog,
                '\tMean Segmentation Loss: %f' % (seg_loss_acc / num_batch))

        def test_one_epoch(epoch_num):
            is_training = False
            total_loss_acc = 0.0
            seg_loss_acc = 0.0
            gt_classes = [0 for _ in range(model.NUM_CATEGORY)]
            positive_classes = [0 for _ in range(model.NUM_CATEGORY)]
            true_positive_classes = [0 for _ in range(model.NUM_CATEGORY)]
            for i in range(test_data.shape[0]):
                pc = np.squeeze(test_data[i, :, :])
                labels = np.squeeze(test_label[i, :]).astype(int)
                seg_label = model.integer_label_to_one_hot_label(labels)
                pointgrid, pointgrid_label, index = model.pc2voxel(
                    pc, seg_label)
                feed_dict = {
                    is_training_ph: is_training,
                    pointgrid_ph: pointgrid,
                    seg_label_ph: pointgrid_label
                }
                total_loss_val, seg_loss_val, pred_seg_val = sess.run(
                    [total_loss, seg_loss, pred_seg], feed_dict=feed_dict)
                total_loss_acc += total_loss_val
                seg_loss_acc += seg_loss_val

                pred_seg_val = pred_seg_val[0, :, :, :, :, :]
                pred_point_label = model.populateOneHotSegLabel(
                    pc, pred_seg_val, index)
                for j in range(pred_point_label.shape[0]):
                    gt_l = int(labels[j])
                    pred_l = int(pred_point_label[j])
                    gt_classes[gt_l - 1] += 1
                    positive_classes[pred_l - 1] += 1
                    true_positive_classes[gt_l - 1] += int(gt_l == pred_l)

            printout(flog, 'gt_l count:{}'.format(gt_classes))
            printout(flog,
                     'positive_classes count:{}'.format(positive_classes))
            printout(
                flog,
                'true_positive_classes count:{}'.format(true_positive_classes))

            iou_list = []
            for i in range(model.SEG_PART):
                iou = true_positive_classes[i] / float(
                    gt_classes[i] + positive_classes[i] -
                    true_positive_classes[i])
                iou_list.append(iou)
            printout(flog, 'IOU:{}'.format(iou_list))
            printout(
                flog, 'ACC:{}'.format(
                    sum(true_positive_classes) / sum(positive_classes)))
            printout(flog,
                     'mIOU:{}'.format(sum(iou_list) / float(model.SEG_PART)))
            printout(
                flog, '\tMean Total Loss: %f' %
                (total_loss_acc / test_data.shape[0]))
            printout(
                flog, '\tMean Segmentation Loss: %f' %
                (seg_loss_acc / test_data.shape[0]))

        if not os.path.exists(MODEL_STORAGE_PATH):
            os.mkdir(MODEL_STORAGE_PATH)

        coord = tf.train.Coordinator()
        for num_thread in range(16):
            t = StoppableThread(target=load_and_enqueue,
                                args=(sess, enqueue_op, pointgrid_ph,
                                      seg_label_ph))
            t.daemon = True
            t.start()
            coord.register_thread(t)

        for epoch in range(TRAINING_EPOCHES):
            printout(
                flog, '\n>>> Training for the epoch %d/%d ...' %
                (epoch + 1, TRAINING_EPOCHES))

            train_one_epoch(epoch)
            # test_one_epoch(epoch)

            if (epoch + 1) % 1 == 0:
                cp_filename = saver.save(
                    sess,
                    os.path.join(MODEL_STORAGE_PATH,
                                 'epoch_' + str(epoch + 1) + '.ckpt'))
                printout(
                    flog, 'Successfully store the checkpoint model into ' +
                    cp_filename)

            flog.flush()
        flog.close()
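
Note: the per-class IoU in test_one_epoch above follows the usual set definition, IoU_c = TP_c / (GT_c + Pred_c - TP_c), i.e. the intersection of the ground-truth and predicted point sets for class c divided by their union; mIOU is the unweighted mean of IoU_c over the model.SEG_PART classes.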